From de2f5e3e6d66ddccb44c4c41ab04260acce6fb2f Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 16:15:56 +0800 Subject: [PATCH 001/120] support RNNLM shallow fusion for LSTM transducer --- .../ASR/lstm_transducer_stateless2/decode.py | 143 ++++- .../beam_search.py | 514 +++++++++--------- icefall/rnn_lm/model.py | 126 ++++- 3 files changed, 503 insertions(+), 280 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index c7b53ebc0..1d46c0177 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -115,7 +115,8 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, - modified_beam_search_ngram_rescoring, + modified_beam_search_rnnlm_shallow_fusion, + ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model @@ -128,6 +129,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -216,7 +218,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified_beam_search_ngram_rescoring + - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -307,21 +309,74 @@ def get_parser(): ) parser.add_argument( - "--tokens-ngram", - type=int, - default=3, - help="""Token Ngram used for rescoring. - Used only when the decoding method is modified_beam_search_ngram_rescoring""", + "--rnn-lm-scale", + type=float, + default=0.0, + help="""Used only when --method is modified_beam_search3. + It specifies the path to RNN LM exp dir. + """, ) parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="""ID of the backoff symbol. - Used only when the decoding method is modified_beam_search_ngram_rescoring""", + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, ) + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + parser.add_argument( + "--ilm-scale", + type=float, + default=-0.1 + ) add_model_arguments(parser) return parser @@ -336,6 +391,8 @@ def decode_one_batch( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -469,14 +526,14 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_ngram_rescoring": - hyp_tokens = modified_beam_search_ngram_rescoring( + elif params.decoding_method == "modified_beam_search_sf_rnnlm": + hyp_tokens = modified_beam_search_sf_rnnlm_batched( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -531,7 +588,9 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + rnnlm: Optional[NgramLm] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. 
Args: @@ -572,6 +631,9 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}") hyps_dict = decode_one_batch( params=params, @@ -582,6 +644,8 @@ def decode_dataset( batch=batch, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -607,7 +671,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -667,7 +731,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_ngram_rescoring", + "modified_beam_search_sf_rnnlm", ) params.res_dir = params.exp_dir / params.decoding_method @@ -692,7 +756,12 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + + if "rnnlm" in params.decoding_method: + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + + if "ILME" in params.decoding_method: + params.suffix += f"-ILME-scale={params.ilm_scale}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -806,14 +875,28 @@ def main(): model.to(device) model.eval() - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") + # only load rnnlm if used + if "rnnlm" in params.decoding_method: + rnn_lm_scale = params.rnn_lm_scale + + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + + else: + rnn_lm_model = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -861,6 +944,8 @@ def main(): decoding_graph=decoding_graph, ngram_lm=ngram_lm, ngram_lm_scale=params.ngram_lm_scale, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0004a24eb..01cc566e8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import k2 import sentencepiece as spm @@ -25,13 +25,8 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - 
get_texts, - get_texts_with_timestamp, -) +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import add_eos, add_sos, get_texts def fast_beam_search_one_best( @@ -43,8 +38,7 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -68,12 +62,8 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -87,11 +77,8 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search_nbest_LG( @@ -106,8 +93,7 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -144,12 +130,8 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -214,10 +196,9 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + + return hyps def fast_beam_search_nbest( @@ -232,8 +213,7 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -270,12 +250,8 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. 
""" lattice = fast_beam_search( model=model, @@ -304,10 +280,9 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + + return hyps def fast_beam_search_nbest_oracle( @@ -323,8 +298,7 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -365,12 +339,8 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -409,10 +379,8 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search( @@ -502,11 +470,8 @@ def fast_beam_search( def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: """Greedy search for a single utterance. Args: model: @@ -516,12 +481,8 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -547,10 +508,6 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -577,7 +534,6 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -592,21 +548,14 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - if not return_timestamps: - return hyp - else: - return DecodingResults( - tokens=[hyp], - timestamps=[timestamp], - ) + return hyp def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -616,12 +565,9 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
+ Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). """ assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -646,10 +592,6 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - decoder_input = torch.tensor( hyps, device=device, @@ -663,7 +605,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -685,7 +627,6 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) - timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -700,19 +641,11 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) + return ans @dataclass @@ -725,11 +658,9 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] - state_cost: Optional[NgramLmStateCost] = None + state: Optional = None + lm_score: Optional=None @property def key(self) -> str: @@ -878,8 +809,7 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -894,12 +824,9 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
""" assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -917,7 +844,7 @@ def modified_beam_search( device = next(model.parameters()).device batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) @@ -927,7 +854,6 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) @@ -935,7 +861,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1013,44 +939,30 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) + return ans def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -1065,13 +977,8 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. 
""" assert encoder_out.ndim == 3 @@ -1091,7 +998,6 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1150,24 +1056,17 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - if not return_timestamps: - return ys - else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return ys def beam_search( @@ -1175,8 +1074,7 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1191,13 +1089,8 @@ def beam_search( Beam size. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -1224,7 +1117,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) max_sym_per_utt = 20000 @@ -1285,13 +1178,7 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1300,14 +1187,7 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1323,11 +1203,7 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return ys def fast_beam_search_with_nbest_rescoring( @@ -1347,8 +1223,7 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: +) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1390,13 +1265,10 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1474,18 +1346,16 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + ans: Dict[str, List[List[int]]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) + ans[key] = hyps return ans @@ -1509,8 +1379,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: +) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1556,13 +1425,10 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. 
""" lattice = fast_beam_search( model=model, @@ -1674,45 +1540,151 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) + ans[key] = hyps return ans +def modified_beam_search_sf_rnnlm( + model: Transducer, + encoder_out: torch.Tensor, + sp, + rnnlm: RnnLmModel, + rnnlm_scale: float, + beam: int = 4, +): + encoder_out = model.joiner.encoder_proj(encoder_out) + lm_scale = rnnlm_scale -def modified_beam_search_ngram_rescoring( + assert rnnlm is not None + assert encoder_out.ndim == 2, encoder_out.shape + rnnlm.clean_cache() + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + T = encoder_out.shape[0] + for t in range(T): + current_encoder_out = encoder_out[t : t + 1] + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyp in A] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + # decoder_out is of shape (num_hyps, joiner_dim) + current_encoder_out = current_encoder_out.repeat(len(A), 1) + # current_encoder_out is of shape (num_hyps, encoder_out_dim) + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk( + beam + ) # get topk tokens and scores + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[hyp_idx] # get hyp + new_ys = hyp.ys[:] + state = "ys=" + "+".join(list(map(str, new_ys))) + tokens = k2.RaggedTensor([new_ys[context_size:]]) + + lm_score = rnnlm.predict( + tokens, state, sos_id, eos_id, blank_id + ) # get rnnlm score + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] # get token + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + # state_cost = hyp.state_cost.forward_one_step(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + else: + new_ys = new_ys + new_log_prob = hyp_log_prob + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + return best_hyp.ys[context_size:] + +def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, + sp: spm.SentencePieceProcessor, + rnnlm: RnnLmModel, + 
rnnlm_scale: float, beam: int = 4, - temperature: float = 1.0, ) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + """Modified_beam_search + RNNLM shallow fusion Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + rnnlm (RnnLmModel): + RNNLM + rnnlm_scale (float): + scale of RNNLM in shallow fusion + beam (int, optional): + Beam size. Defaults to 4. + Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - + assert rnnlm is not None + lm_scale = rnnlm_scale + vocab_size = rnnlm.vocab_size + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, lengths=encoder_out_lens.cpu(), @@ -1721,34 +1693,41 @@ def modified_beam_search_ngram_rescoring( ) blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device - lm_scale = ngram_lm_scale - + batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + init_score, init_states = rnnlm.score_token(sos_token) + B = [HypothesisList() for _ in range(N)] for i in range(N): B[i].add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), + state=init_states, + lm_score=init_score.reshape(-1) ) ) + rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - + offset = 0 finalized_B = [] for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] + current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -1760,49 +1739,44 @@ def modified_beam_search_ngram_rescoring( A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - + ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyps in A for hyp in hyps], device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, 
it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, 1, 1, encoder_out_dim) - + logits = model.joiner( current_encoder_out, decoder_out, project_input=False, ) # (num_hyps, 1, 1, vocab_size) - + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( + log_probs = logits.log_softmax( dim=-1 ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) + log_probs = log_probs.reshape(-1) + row_splits = hyps_shape.row_splits(1) * vocab_size log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() @@ -1810,7 +1784,12 @@ def modified_beam_search_ngram_rescoring( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) + + # for all hyps with a non-blank new token, score it + token_list = [] + hs = [] + cs = [] for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1818,28 +1797,63 @@ def modified_beam_search_ngram_rescoring( warnings.simplefilter("ignore") topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + assert new_token != 0, new_token + token_list.append([new_token]) + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores + if len(token_list) != 0: + tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1) + + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs)) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) - + + ys.append(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + + lm_score = scores[count] + state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1)) + count += 1 + new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score ) - B[i].add(new_hyp) + B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] @@ -1850,4 +1864,4 @@ def modified_beam_search_ngram_rescoring( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans + return ans \ No 
newline at end of file diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 88b2cc41f..2552f65a6 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -18,8 +18,9 @@ import logging import torch import torch.nn.functional as F +import k2 -from icefall.utils import make_pad_mask +from icefall.utils import add_eos, add_sos, make_pad_mask class RnnLmModel(torch.nn.Module): @@ -72,6 +73,8 @@ class RnnLmModel(torch.nn.Module): else: logging.info("Not tying weights") + self.cache = {} + def forward( self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor ) -> torch.Tensor: @@ -118,3 +121,124 @@ class RnnLmModel(torch.nn.Module): nll_loss = nll_loss.reshape(batch_size, -1) return nll_loss + + def get_init_states(self, sos): + p = next(self.parameters()) + + def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): + device = next(self.parameters()).device + batch_size = len(token_lens) + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + + sentence_lengths = ( + sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + ) + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64).to(device) + y_tokens = y_tokens.to(torch.int64).to(device) + sentence_lengths = sentence_lengths.to(torch.int64).to(device) + + embedding = self.input_embedding(x_tokens) + + # Note: We use batch_first==True + rnn_out, states = self.rnn(embedding) + logits = self.output_linear(rnn_out) + mask = torch.zeros(logits.shape).bool().to(device) + for i in range(batch_size): + mask[i, token_lens[i], :] = True + logits = logits[mask].reshape(batch_size, -1) + + return logits[:,:].log_softmax(-1), states + + def clean_cache(self): + self.cache = {} + + def score_token(self, tokens: torch.Tensor, state=None): + device = next(self.parameters()).device + batch_size = tokens.size(0) + if state: + h,c = state + else: + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) + + embedding = self.input_embedding(tokens) + rnn_out, states = self.rnn(embedding, (h,c)) + logits = self.output_linear(rnn_out) + + return logits[:,0].log_softmax(-1), states + + def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None): + batch_size = len(token_lens) + if state: + h,c = state + else: + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + + device = next(self.parameters()).device + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + + sentence_lengths = ( + sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + ) + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64).to(device) + y_tokens = y_tokens.to(torch.int64).to(device) + sentence_lengths = sentence_lengths.to(torch.int64).to(device) + + embedding = self.input_embedding(x_tokens) + + # Note: We use batch_first==True + rnn_out, states = self.rnn(embedding, (h,c)) + logits = self.output_linear(rnn_out) + + return logits, states + +if __name__=="__main__": + LM = RnnLmModel(500, 2048, 2048, 3, True) + h0 = 
torch.zeros(3, 1, 2048) + c0 = torch.zeros(3, 1, 2048) + seq = [[0,1,2,3]] + seq_lens = [len(s) for s in seq] + tokens = k2.RaggedTensor(seq) + output1, state = LM.forward_with_state( + tokens, + seq_lens, + 1, + 1, + 0, + state=(h0,c0) + ) + seq = [[0,1,2,3,4]] + seq_lens = [len(s) for s in seq] + tokens = k2.RaggedTensor(seq) + output2, _ = LM.forward_with_state( + tokens, + seq_lens, + 1, + 1, + 0, + state=(h0,c0) + ) + + seq = [[4]] + seq_lens = [len(s) for s in seq] + output3 = LM.score_token(seq, seq_lens, state) + + print("Finished") + + + From 63d0a52dbd703a0c1692b7ac9f4557fcb0e85df8 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 16:37:29 +0800 Subject: [PATCH 002/120] support RNNLM shallow fusion in stateless5 --- .../beam_search.py | 3 - .../pruned_transducer_stateless5/decode.py | 182 ++++++++++++------ 2 files changed, 124 insertions(+), 61 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 01cc566e8..d569b0752 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -23,7 +23,6 @@ import sentencepiece as spm import torch from model import Transducer -from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.rnn_lm.model import RnnLmModel from icefall.utils import add_eos, add_sos, get_texts @@ -658,8 +657,6 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - state_cost: Optional[NgramLmStateCost] = None - state: Optional = None lm_score: Optional=None @property diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 632932214..59c646717 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -19,36 +19,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ @@ -56,10 +56,10 @@ Usage: --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 30 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir 
./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -69,10 +69,10 @@ Usage: --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -82,10 +82,10 @@ Usage: --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ @@ -115,6 +115,7 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_rnnlm_shallow_fusion, ) from train import add_model_arguments, get_params, get_transducer_model @@ -125,6 +126,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -183,7 +185,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless5/exp", + default="lstm_transducer_stateless2/exp", help="The experiment dir", ) @@ -213,6 +215,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified-beam-search3 # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -240,16 +243,6 @@ def get_parser(): """, ) - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - parser.add_argument( "--max-contexts", type=int, @@ -275,6 +268,7 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -302,28 +296,69 @@ def get_parser(): ) parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. + "--rnn-lm-scale", + type=float, + default=0.0, + help="""Used only when --method is modified_beam_search3. + It specifies the path to RNN LM exp dir. """, ) parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, ) parser.add_argument( - "--left-context", + "--rnn-lm-epoch", type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, ) + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) add_model_arguments(parser) return parser @@ -336,6 +371,8 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -361,7 +398,7 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -474,12 +511,21 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -523,7 +569,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -538,7 +586,7 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
Returns: @@ -564,6 +612,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + logging.info(f"Decoding {batch_idx}-th batch") hyps_dict = decode_one_batch( params=params, @@ -572,6 +621,8 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -597,7 +648,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -657,6 +708,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_sf_rnnlm", ) params.res_dir = params.exp_dir / params.decoding_method @@ -665,10 +717,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -686,6 +734,8 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -706,11 +756,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") @@ -796,6 +841,25 @@ def main(): model.to(device) model.eval() + rnn_lm_model = None + rnn_lm_scale = params.rnn_lm_scale + if params.decoding_method == "modified_beam_search3": + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -839,6 +903,8 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results( From 86662f0b97622c1367fff5e9f974f96ac3874ccf Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:24:53 +0800 Subject: [PATCH 003/120] update results --- egs/librispeech/ASR/RESULTS.md | 53 ++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 92323a556..57dd9f230 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -101,6 +101,7 @@ The WERs are: |-------------------------------------|------------|------------|-------------------------| | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 | | modified_beam_search | 2.73 | 7.15 | 
--iter 468000 --avg 16 | +| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 | | fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 | | greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 | | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 | @@ -155,6 +156,27 @@ for m in greedy_search fast_beam_search modified_beam_search; do done ``` +To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM +can be found here: + +for iter in 472000; do + for avg in 8 10 12 14 16 18; do + ./lstm_transducer_stateless2/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --beam 4 \ + --rnn-lm-scale 0.3 \ + --rnn-lm-exp-dir /path/to/RNNLM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + done +done + Pretrained models, training logs, decoding logs, and decoding results are available at @@ -1311,6 +1333,7 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di |-------------------------------------|------------|------------|-----------------------------------------| | greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 | | fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | ```bash @@ -1356,6 +1379,36 @@ for method in greedy_search modified_beam_search fast_beam_search; do done ``` +To decode with RNNLM shallow fusion, use the following decoding command. 
A well-trained RNNLM +can be found here: + +```bash +for method in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless5/decode.py \ + --epoch 30 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless5/exp-B \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --max-sym-per-frame 1 \ + --num-encoder-layers 24 \ + --dim-feedforward 1536 \ + --nhead 8 \ + --encoder-dim 384 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --use-averaged-model True + --beam 4 \ + --max-contexts 4 \ + --rnn-lm-scale 0.4 \ + --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 +done +``` + You can find a pretrained model, training logs, decoding logs, and decoding results at: From 0a46a39e24a687487eeab3396d35fd395b156a0c Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:25:31 +0800 Subject: [PATCH 004/120] update decoding commands --- .../ASR/lstm_transducer_stateless2/decode.py | 33 +++--- .../beam_search.py | 105 +----------------- .../pruned_transducer_stateless5/decode.py | 95 +++++++++++----- 3 files changed, 88 insertions(+), 145 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 1d46c0177..c43328e08 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -91,6 +91,21 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search (with RNNLM shallow fusion) +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --beam 4 \ + --rnn-lm-scale 0.3 \ + --rnn-lm-exp-dir /path/to/RNNLM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 """ @@ -121,7 +136,6 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -389,8 +403,6 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: @@ -526,11 +538,12 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_sf_rnnlm": - hyp_tokens = modified_beam_search_sf_rnnlm_batched( + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + beam=params.beam_size, sp=sp, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, @@ -586,9 +599,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - rnnlm: Optional[NgramLm] = None, + rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. 
@@ -642,8 +653,6 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, ) @@ -731,7 +740,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_sf_rnnlm", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -942,8 +951,6 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=params.ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index d569b0752..e454bc1a6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -656,6 +657,7 @@ class Hypothesis: # The log prob of ys. # It contains only one entry. log_prob: torch.Tensor + state: Optional=None lm_score: Optional=None @@ -1542,107 +1544,6 @@ def fast_beam_search_with_nbest_rnn_rescoring( ans[key] = hyps return ans - -def modified_beam_search_sf_rnnlm( - model: Transducer, - encoder_out: torch.Tensor, - sp, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, -): - encoder_out = model.joiner.encoder_proj(encoder_out) - lm_scale = rnnlm_scale - - assert rnnlm is not None - assert encoder_out.ndim == 2, encoder_out.shape - rnnlm.clean_cache() - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - T = encoder_out.shape[0] - for t in range(T): - current_encoder_out = encoder_out[t : t + 1] - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyp in A] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - # decoder_out is of shape (num_hyps, joiner_dim) - current_encoder_out = current_encoder_out.repeat(len(A), 1) - # current_encoder_out is of shape (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk( - beam - ) # get topk tokens and scores - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[hyp_idx] # get hyp - new_ys = hyp.ys[:] - state = "ys=" 
+ "+".join(list(map(str, new_ys))) - tokens = k2.RaggedTensor([new_ys[context_size:]]) - - lm_score = rnnlm.predict( - tokens, state, sos_id, eos_id, blank_id - ) # get rnnlm score - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] # get token - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - # state_cost = hyp.state_cost.forward_one_step(new_token) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score - else: - new_ys = new_ys - new_log_prob = hyp_log_prob - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - return best_hyp.ys[context_size:] def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 59c646717..8c69cfd6e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,47 +20,43 @@ """ Usage: (1) greedy search -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method greedy_search - (2) beam search (not recommended) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 - (3) modified beam search -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 - (4) fast beam search (one best) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 - (5) fast beam search (nbest) -./lstm_transducer_stateless2/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -67,12 +64,11 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 - (6) fast beam search (nbest oracle WER) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -80,17 +76,34 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 - (7) fast beam search (with LG) 
-./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search with RNNLM shallow fusion (with LG) +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 4 \ + --max-contexts 4 \ + --rnn-lm-scale 0.4 \ + --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + + """ @@ -243,6 +256,16 @@ def get_parser(): """, ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -294,6 +317,15 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) parser.add_argument( "--rnn-lm-scale", @@ -517,6 +549,9 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -708,7 +743,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_sf_rnnlm", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -843,7 +878,7 @@ def main(): rnn_lm_model = None rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search3": + if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, From babcfd4b68a0f6729161eb1aa0c10e2c2aea2764 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:27:31 +0800 Subject: [PATCH 005/120] update author info --- .../ASR/lstm_transducer_stateless2/decode.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index c43328e08..fc077f062 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,7 +92,7 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 - + (8) modified beam search (with RNNLM shallow fusion) ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ @@ -105,7 +106,7 @@ Usage: --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 + --rnn-lm-tie-weights 1 """ @@ -131,7 +132,6 @@ from beam_search import ( greedy_search_batch, modified_beam_search, 
modified_beam_search_rnnlm_shallow_fusion, - ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model @@ -386,11 +386,7 @@ def get_parser(): last output linear layer """, ) - parser.add_argument( - "--ilm-scale", - type=float, - default=-0.1 - ) + parser.add_argument("--ilm-scale", type=float, default=-0.1) add_model_arguments(parser) return parser @@ -642,9 +638,13 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}") + total_duration = sum( + [cut.duration for cut in batch["supervisions"]["cut"]] + ) + + logging.info( + f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}" + ) hyps_dict = decode_one_batch( params=params, @@ -765,10 +765,10 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - + if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - + if "ILME" in params.decoding_method: params.suffix += f"-ILME-scale={params.ilm_scale}" @@ -903,7 +903,7 @@ def main(): ) rnn_lm_model.to(device) rnn_lm_model.eval() - + else: rnn_lm_model = None From 6c8d1f9ef5feb448565b68a533689794eb83548c Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:48:58 +0800 Subject: [PATCH 006/120] update --- .../beam_search.py | 523 ++++++++++++++---- 1 file changed, 417 insertions(+), 106 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index e454bc1a6..7c5a5ace4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -17,16 +17,23 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import k2 import sentencepiece as spm import torch from model import Transducer +from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import add_eos, add_sos, get_texts +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) def fast_beam_search_one_best( @@ -38,7 +45,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -62,8 +70,12 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
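    A minimal usage sketch (illustrative only; it assumes `model`,
    `decoding_graph`, `encoder_out` and `encoder_out_lens` are already
    prepared, and the beam settings mirror the decode.py examples):

        res = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=20.0,
            max_contexts=8,
            max_states=64,
            return_timestamps=True,
        )
        # res.tokens[i] holds the token IDs of the i-th utterance and
        # res.timestamps[i] the subsampled frame index of each token.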
""" lattice = fast_beam_search( model=model, @@ -77,8 +89,11 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_LG( @@ -93,7 +108,8 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -130,8 +146,12 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -196,9 +216,10 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest( @@ -213,7 +234,8 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -250,8 +272,12 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -280,9 +306,10 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_oracle( @@ -298,7 +325,8 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -339,8 +367,12 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" lattice = fast_beam_search( model=model, @@ -379,8 +411,10 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search( @@ -470,8 +504,11 @@ def fast_beam_search( def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """Greedy search for a single utterance. Args: model: @@ -481,8 +518,12 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -508,6 +549,10 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -534,6 +579,7 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) + timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -548,14 +594,21 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - return hyp + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -565,9 +618,12 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -592,6 +648,10 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + decoder_input = torch.tensor( hyps, device=device, @@ -605,7 +665,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -627,6 +687,7 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) + timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -641,11 +702,19 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) @dataclass @@ -657,9 +726,12 @@ class Hypothesis: # The log prob of ys. # It contains only one entry. log_prob: torch.Tensor - state: Optional=None - lm_score: Optional=None + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] + + state_cost: Optional[NgramLmStateCost] = None @property def key(self) -> str: @@ -808,7 +880,8 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -823,9 +896,12 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -843,7 +919,7 @@ def modified_beam_search( device = next(model.parameters()).device batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) @@ -853,6 +929,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) @@ -860,7 +937,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -938,30 +1015,44 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -976,8 +1067,13 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 @@ -997,6 +1093,7 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1055,17 +1152,24 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def beam_search( @@ -1073,7 +1177,8 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1088,8 +1193,13 @@ def beam_search( Beam size. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -1116,7 +1226,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) max_sym_per_utt = 20000 @@ -1177,7 +1287,13 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1186,7 +1302,14 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1202,7 +1325,11 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def fast_beam_search_with_nbest_rescoring( @@ -1222,7 +1349,8 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1264,10 +1392,13 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1345,16 +1476,18 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, List[List[int]]] = {} + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans @@ -1378,7 +1511,8 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1424,10 +1558,13 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1539,12 +1676,185 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans - + + +def modified_beam_search_ngram_rescoring( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
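    A minimal usage sketch (illustrative only; the LM path, backoff id and
    scale are placeholders, and the transducer inputs are assumed to be
    prepared):

        ngram_lm = NgramLm(
            "path/to/3gram.fst.txt", backoff_id=500, is_binary=False
        )
        hyp_tokens = modified_beam_search_ngram_rescoring(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            ngram_lm=ngram_lm,
            ngram_lm_scale=0.1,
            beam=4,
        )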
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, @@ -1559,18 +1869,18 @@ def modified_beam_search_rnnlm_shallow_fusion( Args: model (Transducer): The transducer model - encoder_out (torch.Tensor): + encoder_out (torch.Tensor): Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: + sp: Sentence piece generator. - rnnlm (RnnLmModel): + rnnlm (RnnLmModel): RNNLM - rnnlm_scale (float): + rnnlm_scale (float): scale of RNNLM in shallow fusion - beam (int, optional): + beam (int, optional): Beam size. Defaults to 4. 
Returns: @@ -1582,7 +1892,7 @@ def modified_beam_search_rnnlm_shallow_fusion( assert rnnlm is not None lm_scale = rnnlm_scale vocab_size = rnnlm.vocab_size - + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, lengths=encoder_out_lens.cpu(), @@ -1592,20 +1902,19 @@ def modified_beam_search_rnnlm_shallow_fusion( blank_id = model.decoder.blank_id sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device - + batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) + N = encoder_out.size(0) assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) init_score, init_states = rnnlm.score_token(sos_token) - + B = [HypothesisList() for _ in range(N)] for i in range(N): B[i].add( @@ -1613,19 +1922,19 @@ def modified_beam_search_rnnlm_shallow_fusion( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, - lm_score=init_score.reshape(-1) + lm_score=init_score.reshape(-1), ) ) rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - + offset = 0 finalized_B = [] for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = encoder_out.data[start:end] # get batch current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) offset = end @@ -1637,44 +1946,42 @@ def modified_beam_search_rnnlm_shallow_fusion( A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - + ys_log_probs = torch.cat( [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] ) - + decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyps in A for hyp in hyps], device=device, dtype=torch.int64, ) # (num_hyps, context_size) - + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) decoder_out = model.joiner.decoder_proj(decoder_out) - + current_encoder_out = torch.index_select( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, 1, 1, encoder_out_dim) - + logits = model.joiner( current_encoder_out, decoder_out, project_input=False, ) # (num_hyps, 1, 1, vocab_size) - + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = logits.log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) - + vocab_size = log_probs.size(-1) log_probs = log_probs.reshape(-1) - + row_splits = hyps_shape.row_splits(1) * vocab_size log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() @@ -1682,7 +1989,6 @@ def modified_beam_search_rnnlm_shallow_fusion( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) - # for all hyps with a non-blank new token, score it token_list = [] @@ -1698,7 +2004,7 @@ def modified_beam_search_rnnlm_shallow_fusion( for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] - + new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): @@ -1708,13 +2014,18 @@ def modified_beam_search_rnnlm_shallow_fusion( 
cs.append(hyp.state[1]) # forward RNNLM to get new states and scores if len(token_list) != 0: - tokens_to_score = torch.tensor(token_list).to(torch.int64).to(device).reshape(-1,1) - + tokens_to_score = ( + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs,cs)) - - count = 0 # index, used to locate score and lm states + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + + count = 0 # index, used to locate score and lm states for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1722,36 +2033,36 @@ def modified_beam_search_rnnlm_shallow_fusion( warnings.simplefilter("ignore") topk_hyp_indexes = (topk_indexes // vocab_size).tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() - + for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] ys = hyp.ys[:] - + lm_score = hyp.lm_score state = hyp.state - + hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - + ys.append(new_token) hyp_log_prob += ( lm_score[new_token] * lm_scale ) # add the lm score - + lm_score = scores[count] - state = (lm_states[0][:, count, :].unsqueeze(1), lm_states[1][:, count, :].unsqueeze(1)) + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 - + new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score + ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score ) - B[i].add(new_hyp) + B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] @@ -1762,4 +2073,4 @@ def modified_beam_search_rnnlm_shallow_fusion( for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - return ans \ No newline at end of file + return ans From 9a01b9098deb56c9c4b048c000b4eead756c98f5 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:03:56 +0800 Subject: [PATCH 007/120] include previous added decoding method --- .../ASR/lstm_transducer_stateless2/decode.py | 65 +++++++++++++++---- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index fc077f062..20a5ebd8b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -131,11 +131,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_ngram_rescoring, modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model +from icefall import NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -232,6 +234,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_ngram_rescoring - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. 
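# --------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch) of the per-token score update
# performed by modified_beam_search_rnnlm_shallow_fusion above; the numbers
# are made-up placeholders.
import torch

am_log_prob = torch.tensor(-1.2)  # transducer log-prob of a non-blank token
lm_log_prob = torch.tensor(-2.5)  # RNNLM log-prob of the same token
rnnlm_scale = 0.3
fused = am_log_prob + rnnlm_scale * lm_log_prob  # score used to rank the hypothesis
# Blank and unk emissions keep only the transducer score; the RNNLM (h, c)
# state advances only when a real token is appended to the hypothesis.
# --------------------------------------------------------------------------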
@@ -386,7 +389,23 @@ def get_parser(): last output linear layer """, ) - parser.add_argument("--ilm-scale", type=float, default=-0.1) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -399,6 +418,8 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: @@ -534,6 +555,17 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( model=model, @@ -595,9 +627,11 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -638,13 +672,6 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration = sum( - [cut.duration for cut in batch["supervisions"]["cut"]] - ) - - logging.info( - f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}" - ) hyps_dict = decode_one_batch( params=params, @@ -653,6 +680,8 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, ) @@ -680,7 +709,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -740,6 +769,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_ngram_rescoring", "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -765,13 +795,10 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - if "ILME" in params.decoding_method: - params.suffix += f"-ILME-scale={params.ilm_scale}" - if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -884,6 +911,14 @@ def main(): model.to(device) model.eval() + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") # only load rnnlm if used if "rnnlm" in params.decoding_method: rnn_lm_scale = params.rnn_lm_scale @@ -951,6 +986,8 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=params.ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) From fb45b95c901de33562d76c277232464fb42bb2bd Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:11:39 +0800 Subject: [PATCH 008/120] minor fixes --- .../pruned_transducer_stateless5/decode.py | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8c69cfd6e..8ba36e582 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -86,7 +86,7 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 - + (8) modified beam search with RNNLM shallow fusion (with LG) ./pruned_transducer_stateless5/decode.py \ --epoch 35 \ @@ -103,7 +103,7 @@ Usage: --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 - + """ @@ -198,7 +198,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="lstm_transducer_stateless2/exp", + default="pruned_transducer_stateless5/exp", help="The experiment dir", ) @@ -228,7 +228,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified-beam-search3 # for rnn lm shallow fusion + - 
modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -265,7 +265,21 @@ def get_parser(): It specifies the scale for n-gram LM scores. """, ) - + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + parser.add_argument( "--max-contexts", type=int, @@ -317,7 +331,7 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - + parser.add_argument( "--simulate-streaming", type=str2bool, @@ -331,7 +345,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -430,7 +444,7 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -560,7 +574,7 @@ def decode_one_batch( for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -606,7 +620,7 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -683,7 +697,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -751,7 +765,9 @@ def main(): params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -791,6 +807,11 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") From b62fd917ae54fb0305a3f4fac931d850bfe231c1 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 18:17:05 +0800 Subject: [PATCH 009/120] remove redundant test lines --- icefall/rnn_lm/model.py | 88 ++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 59 deletions(-) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 2552f65a6..a6144727a 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -18,7 +18,6 @@ import logging import torch import torch.nn.functional as F -import k2 from icefall.utils import add_eos, add_sos, make_pad_mask @@ -121,9 +120,6 @@ class RnnLmModel(torch.nn.Module): nll_loss = nll_loss.reshape(batch_size, -1) return nll_loss - - def get_init_states(self, sos): - p = next(self.parameters()) def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): device = next(self.parameters()).device @@ -153,35 +149,45 @@ class RnnLmModel(torch.nn.Module): for i in range(batch_size): mask[i, token_lens[i], :] = True logits = logits[mask].reshape(batch_size, -1) - - return logits[:,:].log_softmax(-1), states - + + return logits[:, :].log_softmax(-1), states + def clean_cache(self): self.cache = {} - + def score_token(self, tokens: torch.Tensor, state=None): device = next(self.parameters()).device batch_size = tokens.size(0) if state: - h,c = state + h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(device) - - embedding = self.input_embedding(tokens) - rnn_out, states = self.rnn(embedding, (h,c)) - logits = self.output_linear(rnn_out) - - return logits[:,0].log_softmax(-1), states + h = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) - def forward_with_state(self, tokens, token_lens, sos_id, eos_id, blank_id, state=None): + embedding = self.input_embedding(tokens) + rnn_out, states = self.rnn(embedding, (h, c)) + logits = self.output_linear(rnn_out) + + return logits[:, 0].log_softmax(-1), states + + def forward_with_state( + self, tokens, token_lens, sos_id, eos_id, blank_id, state=None + ): batch_size = len(token_lens) if state: - h,c = state + h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - + h = 
torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + device = next(self.parameters()).device sos_tokens = add_sos(tokens, sos_id) @@ -202,43 +208,7 @@ class RnnLmModel(torch.nn.Module): embedding = self.input_embedding(x_tokens) # Note: We use batch_first==True - rnn_out, states = self.rnn(embedding, (h,c)) + rnn_out, states = self.rnn(embedding, (h, c)) logits = self.output_linear(rnn_out) return logits, states - -if __name__=="__main__": - LM = RnnLmModel(500, 2048, 2048, 3, True) - h0 = torch.zeros(3, 1, 2048) - c0 = torch.zeros(3, 1, 2048) - seq = [[0,1,2,3]] - seq_lens = [len(s) for s in seq] - tokens = k2.RaggedTensor(seq) - output1, state = LM.forward_with_state( - tokens, - seq_lens, - 1, - 1, - 0, - state=(h0,c0) - ) - seq = [[0,1,2,3,4]] - seq_lens = [len(s) for s in seq] - tokens = k2.RaggedTensor(seq) - output2, _ = LM.forward_with_state( - tokens, - seq_lens, - 1, - 1, - 0, - state=(h0,c0) - ) - - seq = [[4]] - seq_lens = [len(s) for s in seq] - output3 = LM.score_token(seq, seq_lens, state) - - print("Finished") - - - From e3f218b62b13408e4688129efd12acb182077bf6 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:10:23 +0800 Subject: [PATCH 010/120] Update egs/librispeech/ASR/lstm_transducer_stateless2/decode.py Co-authored-by: Fangjun Kuang --- egs/librispeech/ASR/lstm_transducer_stateless2/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 20a5ebd8b..ac17da207 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -329,7 +329,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) From 8f79f6de007883be07fe2d9521c5c4efc57c0ab9 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Wed, 2 Nov 2022 23:36:07 +0800 Subject: [PATCH 011/120] Update tdnn_lstm_ctc.rst (#647) --- docs/source/recipes/aishell/tdnn_lstm_ctc.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst index 275931698..9eb3b11f7 100644 --- a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst @@ -498,7 +498,7 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained |aishell asr conformer ctc colab notebook| .. |aishell asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg - :target: https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z + :target: https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing **Congratulations!** You have finished the aishell ASR recipe with TDNN-LSTM CTC models in ``icefall``. 
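# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch) of the incremental scoring API
# added to icefall/rnn_lm/model.py above; `rnnlm` is assumed to be an
# already-constructed RnnLmModel on the right device, and the token IDs are
# placeholders.
import torch

sos = torch.tensor([[1]], dtype=torch.int64)        # shape (batch=1, 1)
log_probs, states = rnnlm.score_token(sos)          # a fresh (h, c) is created
nxt = torch.tensor([[42]], dtype=torch.int64)
log_probs, states = rnnlm.score_token(nxt, states)  # reuse the returned states
# log_probs: (batch, vocab_size) log-softmax over the next token;
# states: the LSTM (h, c) pair to feed back at the next step.
# --------------------------------------------------------------------------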
From 04671b44f85dbd1cb23a93015ccea3fbcd1156f6 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Wed, 2 Nov 2022 23:36:40 +0800 Subject: [PATCH 012/120] Update README.md (#649) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7213d8460..83ce0ac16 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ The WER for this model is: |-----|------------|------------| | WER | 6.59 | 17.69 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) +We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) #### Transducer: Conformer encoder + LSTM decoder @@ -162,7 +162,7 @@ The CER for this model is: |-----|-------| | CER | 10.16 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing) +We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) ### TIMIT From 5d285625cf14f6f1d7108d650ad58856ac0f7cd6 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Wed, 2 Nov 2022 23:37:01 +0800 Subject: [PATCH 013/120] Update tdnn_lstm_ctc.rst (#648) --- docs/source/recipes/librispeech/tdnn_lstm_ctc.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst index ca477fbaa..aa380396a 100644 --- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -398,7 +398,7 @@ We provide a colab notebook for decoding with pre-trained model. |librispeech tdnn_lstm_ctc colab notebook| .. |librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg - :target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd + :target: https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing **Congratulations!** You have finished the TDNN-LSTM-CTC recipe on librispeech in ``icefall``. From d2a1c65c5cc058e848d949cd99283e1981d499cc Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Thu, 3 Nov 2022 11:27:18 +0900 Subject: [PATCH 014/120] fix torchaudio version in dockerfile (#653) * fix torchaudio version in dockerfile * remove kaldiio --- docker/README.md | 4 ++-- docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile | 5 +++-- docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docker/README.md b/docker/README.md index 0a39b7a49..6f2314e96 100644 --- a/docker/README.md +++ b/docker/README.md @@ -72,14 +72,14 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. 
Do this by specifying `-v {/path/in/docker}:{/path/in/host/machine}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. Overall, your docker run command should look like this. ```bash -docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/docker}:{/path/in/host/machine} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 +docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 ``` You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment. diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index db4dda864..524303fb8 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -51,8 +51,9 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - -RUN pip install kaldiio graphviz && \ - conda install -y -c pytorch torchaudio +RUN conda install -y -c pytorch torchaudio=0.12 && \ + pip install graphviz + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 7a14a00ad..17a8215f9 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -69,8 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - -RUN pip install kaldiio graphviz && \ - conda install -y -c pytorch torchaudio=0.7.1 +RUN conda install -y -c pytorch torchaudio=0.7.1 && \ + pip install graphviz #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ From 2a52b8c125019feb305275b4e356ea5969a35046 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 3 Nov 2022 11:10:21 +0800 Subject: [PATCH 015/120] update docs --- .../ASR/lstm_transducer_stateless2/decode.py | 35 +++++++++++-------- .../beam_search.py | 25 ++++++++++--- .../pruned_transducer_stateless5/decode.py | 8 ++--- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 20a5ebd8b..40a0d5bf7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -235,7 +235,7 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use 
fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -329,7 +329,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -338,7 +338,7 @@ def get_parser(): "--rnn-lm-exp-dir", type=str, default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -347,7 +347,7 @@ def get_parser(): "--rnn-lm-epoch", type=int, default=7, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the checkpoint to use. """, ) @@ -356,7 +356,7 @@ def get_parser(): "--rnn-lm-avg", type=int, default=2, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the number of checkpoints to average. """, ) @@ -911,14 +911,20 @@ def main(): model.to(device) model.eval() - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") + # only load N-gram LM when needed + if "ngram" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + else: + ngram_lm = None + ngram_lm_scale = None + # only load rnnlm if used if "rnnlm" in params.decoding_method: rnn_lm_scale = params.rnn_lm_scale @@ -941,6 +947,7 @@ def main(): else: rnn_lm_model = None + rnn_lm_scale = 0.0 if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -987,7 +994,7 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, ngram_lm=ngram_lm, - ngram_lm_scale=params.ngram_lm_scale, + ngram_lm_scale=ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7c5a5ace4..480146a59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -17,7 +17,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import k2 import sentencepiece as spm @@ -729,8 +729,15 @@ class Hypothesis: # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded - timestamp: List[int] + timestamp: List[int] = None + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state state_cost: Optional[NgramLmStateCost] = None @property @@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) - - # for all hyps with a non-blank new token, score it + """ 
+ for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + The RNNLM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ token_list = [] hs = [] cs = [] @@ -2007,11 +2021,12 @@ def modified_beam_search_rnnlm_shallow_fusion( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token token_list.append([new_token]) + # store the LSTM states hs.append(hyp.state[0]) cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8ba36e582..2711c4cc9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -228,7 +228,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -354,7 +354,7 @@ def get_parser(): "--rnn-lm-exp-dir", type=str, default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -363,7 +363,7 @@ def get_parser(): "--rnn-lm-epoch", type=int, default=7, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the checkpoint to use. """, ) @@ -372,7 +372,7 @@ def get_parser(): "--rnn-lm-avg", type=int, default=2, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the number of checkpoints to average. 
""", ) From 163d929601d0130f51bad266f107f9dfbaf6fde4 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 3 Nov 2022 16:29:30 +0800 Subject: [PATCH 016/120] Add fast_beam_search_LG (#622) * Add fast_beam_search_LG * add fast_beam_search_LG to commonly used recipes * fix ci * fix ci * Fix error --- ...pruned-transducer-stateless2-2022-04-29.sh | 1 + ...pruned-transducer-stateless3-2022-04-29.sh | 1 + .../ASR/pruned_transducer_stateless/decode.py | 33 ++++++++++------ .../beam_search.py | 4 +- .../pruned_transducer_stateless2/decode.py | 39 ++++++++++++------- .../pruned_transducer_stateless3/decode.py | 33 ++++++++++------ .../pruned_transducer_stateless4/decode.py | 25 +++++++----- .../pruned_transducer_stateless5/decode.py | 33 ++++++++++------ icefall/utils.py | 7 +++- 9 files changed, 113 insertions(+), 63 deletions(-) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index ae2bb6822..c3d07dc0e 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -83,4 +83,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless2/exp/*.pt + rm -r data/lang_bpe_500 fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 00580ca1f..22de3b45d 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -82,4 +82,5 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless3/exp/*.pt + rm -r data/lang_bpe_500 fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 3977f8443..ab23a5a83 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -206,6 +206,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -230,7 +231,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -241,7 +242,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. 
""", ) @@ -250,7 +251,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -259,7 +260,7 @@ def get_parser(): "--max-states", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -355,8 +356,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of @@ -387,7 +388,10 @@ def decode_one_batch( ) hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -397,8 +401,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -526,8 +534,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -643,6 +651,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -737,7 +746,7 @@ def main(): model.device = device if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0004a24eb..4f5016e94 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -15,7 +15,7 @@ # limitations under the License. 
import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, List, Optional, Union import k2 @@ -727,7 +727,7 @@ class Hypothesis: # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded - timestamp: List[int] + timestamp: List[int] = field(default_factory=list) state_cost: Optional[NgramLmStateCost] = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 3b834b919..99d4b5702 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -212,6 +212,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -247,8 +248,8 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. + Used only when --decoding_method is fast_beam_search_LG and + fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -256,7 +257,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -265,7 +266,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -363,9 +364,10 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -401,7 +403,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -411,8 +416,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -548,9 +557,10 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + The decoding graph. 
Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, and + fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -663,6 +673,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -757,7 +768,7 @@ def main(): model.device = device if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 0f30792e3..f34cf1e1f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -202,6 +202,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -226,7 +227,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -237,7 +238,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -246,7 +247,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -255,7 +256,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is, fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -440,8 +441,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. G: Optional. 
Used only when decoding method is fast_beam_search, @@ -483,7 +484,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -494,8 +498,12 @@ def decode_one_batch( max_states=params.max_states, temperature=params.temperature, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -714,8 +722,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. G: Optional. Used only when decoding method is fast_beam_search, @@ -901,6 +909,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -1002,7 +1011,7 @@ def main(): G = None if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 85097a01a..6afc21ce7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -243,6 +243,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -267,7 +268,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -278,7 +279,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. 
""", ) @@ -287,7 +288,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -296,7 +297,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -394,8 +395,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result and timestamps. See above description for the @@ -430,7 +431,10 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -579,8 +583,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -742,6 +746,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -886,7 +891,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 632932214..c27d78e34 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -210,6 +210,7 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -234,7 +235,7 @@ def get_parser(): help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. 
- Used only when --decoding-method is fast_beam_search, + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, @@ -245,7 +246,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. It specifies the scale for n-gram LM scores. """, ) @@ -254,7 +255,7 @@ def get_parser(): "--max-contexts", type=int, default=8, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -263,7 +264,7 @@ def get_parser(): "--max-states", type=int, default=64, - help="""Used only when --decoding-method is + help="""Used only when --decoding-method is fast_beam_search_LG, fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -361,8 +362,8 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of @@ -399,7 +400,10 @@ def decode_one_batch( hyps = [] - if params.decoding_method == "fast_beam_search": + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, @@ -409,8 +413,12 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -538,8 +546,8 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -653,6 +661,7 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -797,7 +806,7 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": + if "LG" in params.decoding_method: lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" diff --git a/icefall/utils.py b/icefall/utils.py index 45a49fb5c..93dd0b967 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1369,6 +1369,7 @@ def parse_hyp_and_timestamp( - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_LG - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG @@ -1388,6 +1389,7 @@ def parse_hyp_and_timestamp( "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_LG", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", @@ -1400,7 +1402,10 @@ def parse_hyp_and_timestamp( N = len(res.tokens) assert len(res.timestamps) == N use_word_table = False - if decoding_method == "fast_beam_search_nbest_LG": + if ( + decoding_method == "fast_beam_search_nbest_LG" + and decoding_method == "fast_beam_search_LG" + ): assert word_table is not None use_word_table = True From 64aed2cdeb8aae9ae98af8f6211f696fec2ee9d8 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 3 Nov 2022 23:12:35 +0800 Subject: [PATCH 017/120] Fix LG log file name (#657) --- egs/librispeech/ASR/pruned_transducer_stateless/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless3/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless4/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/decode.py | 8 ++++---- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ab23a5a83..7b6338948 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -504,8 +504,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -675,8 +675,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 99d4b5702..979a0e02e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -527,8 +527,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += 
f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -697,8 +697,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index f34cf1e1f..8025d6be1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -686,8 +686,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: return { @@ -936,8 +936,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 6afc21ce7..7003e4764 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -551,8 +551,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: (hyps, timestamps)} else: @@ -770,8 +770,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index c27d78e34..d251246b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -516,8 +516,8 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += 
f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -685,8 +685,8 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" From 0df597291f71bd9c22f28b7482a9ff636dfab351 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 4 Nov 2022 11:17:56 +0800 Subject: [PATCH 018/120] resolve conflict with timestamp feature --- egs/librispeech/ASR/beam_search.py | 1821 ++++++++++++++++++++++++++++ 1 file changed, 1821 insertions(+) create mode 100644 egs/librispeech/ASR/beam_search.py diff --git a/egs/librispeech/ASR/beam_search.py b/egs/librispeech/ASR/beam_search.py new file mode 100644 index 000000000..cc5c1c09d --- /dev/null +++ b/egs/librispeech/ASR/beam_search.py @@ -0,0 +1,1821 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import k2 +import sentencepiece as spm +import torch +from model import Transducer + +from icefall import NgramLm, NgramLmStateCost +from icefall.decode import Nbest, one_best_decoding +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) + + +def fast_beam_search_one_best( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. 
+ Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + best_path = one_best_decoding(lattice) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_LG( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. 
+ # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. 
+ + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_oracle( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + ref_texts: List[List[int]], + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + we select `num_paths` linear paths from the lattice. The path + that has the minimum edit distance with the given reference transcript + is used as the output. + + This is the best result we can achieve for any nbest based rescoring + methods. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + ref_texts: + A list-of-list of integers containing the reference transcripts. + If the decoding_graph is a trivial_graph, the integer ID is the + BPE token ID. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
+ """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + hyps = nbest.build_levenshtein_graphs() + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. 
+ """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = (logits / temperature).log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + return lattice + + +def greedy_search( + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + unk_id = getattr(model, "unk_id", blank_id) + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + + # Maximum symbols per utterance. 
+ max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits is (1, 1, 1, vocab_size) + + y = logits.argmax().item() + if y not in (blank_id, unk_id): + hyp.append(y) + timestamp.append(t) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) + + +def greedy_search_batch( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = next(model.parameters()).device + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out: (N, 1, decoder_out_dim) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v not in (blank_id, unk_id): + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in 
hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + state: Optional = None + + lm_score: Optional = None + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. 
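+
+        Example (illustrative; mirrors how `beam_search` below prunes its
+        active hypotheses):
+
+            kept_B = B.filter(A_most_probable.log_prob)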
+ """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
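+
+    Example (an illustrative sketch; assumes `model`, `encoder_out` and
+    `encoder_out_lens` are already prepared; the beam value is a placeholder):
+
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=4,
+        )
+        # hyp_tokens[i] holds the token IDs decoded for the i-th utterance.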
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
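+        # Repeat each utterance's frame-t encoder output once per active
+        # hypothesis; hyps_shape.row_ids(1) maps every hypothesis to the
+        # index of the utterance it belongs to.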
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, +) -> List[int]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + temperature: + Softmax temperature. + Returns: + Return the decoded result. 
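+
+    Example (illustrative; decodes a single utterance, i.e. encoder_out has
+    shape (1, T, C); the beam value is a placeholder):
+
+        hyp = beam_search(model=model, encoder_out=encoder_out, beam=4)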
+ """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = (logits / temperature).log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i in (blank_id, unk_id): + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + return ys + + +def fast_beam_search_with_nbest_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. 
+ A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model. The shortest path within the + lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. 
We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, List[List[int]]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model and a rnn-lm. + The shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + A rnn-lm model used for LM rescoring + rnn_lm_scale_list: + A list of floats representing RNN score scales. 
+ oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in 
rnn_lm_scale_list: + key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" + tot_scores = ( + am_scores.values + + n_scale * ngram_lm_scores + + rnn_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans + + +def modified_beam_search_rnnlm_shallow_fusion( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + rnnlm: RnnLmModel, + rnnlm_scale: float, + beam: int = 4, +) -> List[List[int]]: + """Modified_beam_search + RNNLM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + rnnlm (RnnLmModel): + RNNLM + rnnlm_scale (float): + scale of RNNLM in shallow fusion + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert rnnlm is not None + lm_scale = rnnlm_scale + vocab_size = rnnlm.vocab_size + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + eos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + init_score, init_states = rnnlm.score_token(sos_token) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + ) + ) + + rnnlm.clean_cache() + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + 
dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + # for all hyps with a non-blank new token, score it + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + assert new_token != 0, new_token + token_list.append([new_token]) + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores + if len(token_list) != 0: + tokens_to_score = ( + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) + ) + + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score + + lm_score = scores[count] + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans From bdaeaae1ae33021fadd1f8e9b6bbe45f1bf5ae00 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 4 Nov 2022 11:25:10 +0800 Subject: [PATCH 019/120] resolve conflicts --- .../beam_search.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 480146a59..b1fd75204 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple, Union
 
 import k2
@@ -729,7 +729,7 @@ class Hypothesis:
 
     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
-    timestamp: List[int] = None
+    timestamp: List[int] = field(default_factory=list)
 
     # the lm score for next token given the current ys
    lm_score: Optional[torch.Tensor] = None
@@ -1870,6 +1870,7 @@ def modified_beam_search_rnnlm_shallow_fusion(
     rnnlm: RnnLmModel,
     rnnlm_scale: float,
     beam: int = 4,
+    return_timestamps: bool = False,
 ) -> List[List[int]]:
     """Modified_beam_search + RNNLM shallow fusion
 
@@ -1930,6 +1931,7 @@
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
             state=init_states,
             lm_score=init_score.reshape(-1),
+            timestamp=[],
         )
     )
 
@@ -1938,7 +1940,7 @@
 
     offset = 0
     finalized_B = []
-    for batch_size in batch_size_list:
+    for (t, batch_size) in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]  # get batch
@@ -2060,9 +2062,11 @@
 
                 hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                 new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
                 if new_token not in (blank_id, unk_id):
 
                     ys.append(new_token)
+                    new_timestamp.append(t)
                     hyp_log_prob += (
                         lm_score[new_token] * lm_scale
                     )  # add the lm score
@@ -2075,7 +2079,11 @@
                         count += 1
 
                 new_hyp = Hypothesis(
-                    ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    timestamp=new_timestamp,
                 )
                 B[i].add(new_hyp)
 
@@ -2083,9 +2091,18 @@
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
 
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
     ans = []
+    ans_timestamps = []
     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
 
-    return ans
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )

From b3c61b85e3340f5d5f68c3f09659e5a05d052665 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Fri, 4 Nov 2022 11:32:09 +0800
Subject: [PATCH 020/120] minor fixes

---
 egs/librispeech/ASR/pruned_transducer_stateless5/decode.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 2711c4cc9..96aa66c29 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -636,8 +636,8 @@ def decode_dataset(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. From 2271c3d396139e515f94f1f681f5086b27442937 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 4 Nov 2022 12:26:38 +0800 Subject: [PATCH 021/120] remove testing file --- egs/librispeech/ASR/beam_search.py | 1821 ---------------------------- 1 file changed, 1821 deletions(-) delete mode 100644 egs/librispeech/ASR/beam_search.py diff --git a/egs/librispeech/ASR/beam_search.py b/egs/librispeech/ASR/beam_search.py deleted file mode 100644 index cc5c1c09d..000000000 --- a/egs/librispeech/ASR/beam_search.py +++ /dev/null @@ -1,1821 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass -from typing import Dict, List, Optional, Union - -import k2 -import sentencepiece as spm -import torch -from model import Transducer - -from icefall import NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
- """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. 
- # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. 
- - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_oracle( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. 
- """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. 
- """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = (logits / temperature).log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. 
- max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - tokens=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in 
hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - state: Optional = None - - lm_score: Optional = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. 
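
A toy, self-contained illustration (with made-up scores) of the log-sum-exp merge performed in HypothesisList.add above when two search paths reach the same token prefix:

    import torch

    log_p1 = torch.tensor([-2.3])  # log-prob of path 1
    log_p2 = torch.tensor([-3.1])  # log-prob of path 2 with the same prefix
    merged = torch.logaddexp(log_p1, log_p2)
    # merged ≈ -1.93, i.e. log(exp(-2.3) + exp(-3.1)): the duplicate prefix is
    # kept once, carrying the combined probability mass of both paths.
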
- """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. 
- """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. 
- current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. 
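
A toy, self-contained view (assuming a beam of 2 and a 5-token vocabulary) of how the flattened top-k indexes above are split back into hypothesis and token indexes:

    import torch

    vocab_size = 5
    # Flattened scores of 2 hypotheses x 5 tokens for one utterance.
    log_probs = torch.tensor(
        [-1.0, -9.0, -8.0, -7.0, -6.0,   # hypothesis 0
         -2.0, -1.5, -9.0, -9.0, -9.0]   # hypothesis 1
    )
    topk_log_probs, topk_indexes = log_probs.topk(2)
    topk_hyp_indexes = (topk_indexes // vocab_size).tolist()   # [0, 1]
    topk_token_indexes = (topk_indexes % vocab_size).tolist()  # [0, 1]
    # The best expansion comes from hypothesis 0 with token 0 (score -1.0),
    # the second best from hypothesis 1 with token 1 (score -1.5).
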
- """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - return ys - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - Returns: - Return the decoded result. 
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys - - -def fast_beam_search_with_nbest_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: - """It limits the maximum number of symbols per frame to 1. 
- A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. 
We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, List[List[int]]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - ans[key] = hyps - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. 
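
A self-contained toy example (scores are made up) of the per-scale combination of acoustic and n-gram LM scores performed at the end of fast_beam_search_with_nbest_rescoring above:

    import torch

    am_scores = torch.tensor([-9.0, -11.0])         # acoustic scores of 2 paths
    ngram_lm_scores = torch.tensor([-25.0, -10.0])  # n-gram LM scores
    for s in (0.1, 0.3):
        tot_scores = am_scores + s * ngram_lm_scores
        best = tot_scores.argmax().item()
        # s = 0.1 -> path 0 wins (-11.5 vs -12.0); s = 0.3 -> path 1 wins
        # (-14.0 vs -16.5). This is why a list of scales is swept and each
        # result is stored under its own "ngram_lm_scale_*" key.
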
- oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in 
rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - ans[key] = hyps - - return ans - - -def modified_beam_search_rnnlm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, -) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - ) - ) - - rnnlm.clean_cache() - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - 
dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - # for all hyps with a non-blank new token, score it - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - - assert new_token != 0, new_token - token_list.append([new_token]) - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - # forward RNNLM to get new states and scores - if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) - ) - - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score - - lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, log_prob=hyp_log_prob, state=state, lm_score=lm_score - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans From 3600ce1b5fc528683f7f5eed029bbcce04b05a4a Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 4 Nov 2022 16:10:09 +0800 Subject: [PATCH 022/120] Apply delay penalty on transducer (#654) * add delay penalty * fix CI * fix CI --- .github/workflows/test.yml | 3 +++ .../ASR/lstm_transducer_stateless/model.py | 8 ++++++++ .../ASR/lstm_transducer_stateless/train.py | 11 +++++++++++ .../ASR/lstm_transducer_stateless2/model.py | 8 ++++++++ .../ASR/lstm_transducer_stateless2/train.py | 11 +++++++++++ 
.../ASR/lstm_transducer_stateless3/train.py | 11 +++++++++++ .../ASR/pruned_transducer_stateless/test_model.py | 5 ----- .../ASR/pruned_transducer_stateless2/model.py | 9 +++++++++ .../ASR/pruned_transducer_stateless2/train.py | 11 +++++++++++ .../ASR/pruned_transducer_stateless3/model.py | 8 ++++++++ .../ASR/pruned_transducer_stateless3/train.py | 11 +++++++++++ .../ASR/pruned_transducer_stateless4/train.py | 11 +++++++++++ .../ASR/pruned_transducer_stateless5/train.py | 11 +++++++++++ 13 files changed, 113 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1583926ec..04fc0265f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,6 +79,9 @@ jobs: pip uninstall -y protobuf pip install --no-binary protobuf protobuf + pip install kaldifst + pip install onnxruntime + pip install -r requirements.txt - name: Install graphviz diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index efbc88a55..d71132b4a 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -81,6 +81,7 @@ class Transducer(nn.Module): lm_scale: float = 0.0, warmup: float = 1.0, reduction: str = "sum", + delay_penalty: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -108,6 +109,11 @@ class Transducer(nn.Module): "sum" to sum the losses over all utterances in the batch. "none" to return the loss in a 1-D tensor for each utterance in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. Returns: Return the transducer loss. @@ -164,6 +170,7 @@ class Transducer(nn.Module): am_only_scale=am_scale, boundary=boundary, reduction=reduction, + delay_penalty=delay_penalty, return_grad=True, ) @@ -196,6 +203,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, + delay_penalty=delay_penalty, reduction=reduction, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index a50686df9..fbb4e7224 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -318,6 +318,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. 
+ See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -611,6 +621,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index b0fb6ab89..fadeb4ac2 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -106,6 +106,7 @@ class Transducer(nn.Module): lm_scale: float = 0.0, warmup: float = 1.0, reduction: str = "sum", + delay_penalty: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -136,6 +137,11 @@ class Transducer(nn.Module): "sum" to sum the losses over all utterances in the batch. "none" to return the loss in a 1-D tensor for each utterance in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. Returns: Return the transducer loss. @@ -203,6 +209,7 @@ class Transducer(nn.Module): am_only_scale=am_scale, boundary=boundary, reduction=reduction, + delay_penalty=delay_penalty, return_grad=True, ) @@ -235,6 +242,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, + delay_penalty=delay_penalty, reduction=reduction, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 9eed2dfcb..ac6bf7e04 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -341,6 +341,16 @@ def get_parser(): help="The probability to select a batch from the GigaSpeech dataset", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -665,6 +675,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index fa50576d8..f2aa84625 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -328,6 +328,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. 
+ See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -623,6 +633,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py index 1858d6bf0..fc82d8c69 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py @@ -23,7 +23,6 @@ To run this file, do: python ./pruned_transducer_stateless/test_model.py """ -import torch from train import get_params, get_transducer_model @@ -43,8 +42,6 @@ def test_model(): num_param = sum([p.numel() for p in model.parameters()]) print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) def test_model_streaming(): @@ -63,8 +60,6 @@ def test_model_streaming(): num_param = sum([p.numel() for p in model.parameters()]) print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index ba7616c61..417c391d9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -81,6 +81,7 @@ class Transducer(nn.Module): lm_scale: float = 0.0, warmup: float = 1.0, reduction: str = "sum", + delay_penalty: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -108,6 +109,12 @@ class Transducer(nn.Module): "sum" to sum the losses over all utterances in the batch. "none" to return the loss in a 1-D tensor for each utterance in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. + Returns: Returns: Return the transducer loss. @@ -164,6 +171,7 @@ class Transducer(nn.Module): am_only_scale=am_scale, boundary=boundary, reduction=reduction, + delay_penalty=delay_penalty, return_grad=True, ) @@ -196,6 +204,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, + delay_penalty=delay_penalty, reduction=reduction, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 5c2f67534..7ce2ca779 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -317,6 +317,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. 
+ See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -607,6 +617,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 0d5f7cc6d..7852f84e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -106,6 +106,7 @@ class Transducer(nn.Module): lm_scale: float = 0.0, warmup: float = 1.0, reduction: str = "sum", + delay_penalty: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -136,6 +137,11 @@ class Transducer(nn.Module): "sum" to sum the losses over all utterances in the batch. "none" to return the loss in a 1-D tensor for each utterance in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. Returns: Return the transducer loss. @@ -203,6 +209,7 @@ class Transducer(nn.Module): am_only_scale=am_scale, boundary=boundary, reduction=reduction, + delay_penalty=delay_penalty, return_grad=True, ) @@ -235,6 +242,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, + delay_penalty=delay_penalty, reduction=reduction, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index a74975caf..6cc34f18a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -328,6 +328,16 @@ def get_parser(): help="The probability to select a batch from the GigaSpeech dataset", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -645,6 +655,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 4c55fd609..57548270d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -335,6 +335,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. 
+ See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -638,6 +648,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 1fa668293..b964cd05d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -368,6 +368,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + add_model_arguments(parser) return parser @@ -662,6 +672,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, ) simple_loss_is_finite = torch.isfinite(simple_loss) pruned_loss_is_finite = torch.isfinite(pruned_loss) From 32de2766d591d2e1a77c06a40d2861fb1bbcd3ad Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 5 Nov 2022 22:36:06 +0800 Subject: [PATCH 023/120] Refactor getting timestamps in fsa-based decoding (#660) * refactor getting timestamps for fsa-based decoding * fix doc * fix bug --- .../ASR/lstm_transducer_stateless3/decode.py | 2 +- .../beam_search.py | 10 +-- .../pruned_transducer_stateless4/decode.py | 2 +- icefall/utils.py | 74 +++++++++---------- 4 files changed, 41 insertions(+), 47 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 052d027e3..9eee19379 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -487,7 +487,7 @@ def decode_one_batch( ) tokens.extend(res.tokens) timestamps.extend(res.timestamps) - res = DecodingResults(tokens=tokens, timestamps=timestamps) + res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( decoding_method=params.decoding_method, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b1fd75204..a3fa6cc7c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -598,7 +598,7 @@ def greedy_search( return hyp else: return DecodingResults( - tokens=[hyp], + hyps=[hyp], timestamps=[timestamp], ) @@ -712,7 +712,7 @@ def greedy_search_batch( return ans else: return DecodingResults( - tokens=ans, + hyps=ans, timestamps=ans_timestamps, ) @@ -1049,7 +1049,7 @@ def modified_beam_search( return ans else: return DecodingResults( - tokens=ans, + hyps=ans, timestamps=ans_timestamps, ) @@ -1176,7 +1176,7 @@ def _deprecated_modified_beam_search( if not return_timestamps: return ys else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) def beam_search( @@ -1336,7 +1336,7 @@ def 
beam_search( if not return_timestamps: return ys else: - return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) def fast_beam_search_with_nbest_rescoring( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 7003e4764..4f043e5a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -531,7 +531,7 @@ def decode_one_batch( ) tokens.extend(res.tokens) timestamps.extend(res.timestamps) - res = DecodingResults(tokens=tokens, timestamps=timestamps) + res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( decoding_method=params.decoding_method, diff --git a/icefall/utils.py b/icefall/utils.py index 93dd0b967..e83fccdde 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -251,33 +251,20 @@ def get_texts( @dataclass class DecodingResults: - # Decoded token IDs for each utterance in the batch - tokens: List[List[int]] - # timestamps[i][k] contains the frame number on which tokens[i][k] # is decoded timestamps: List[List[int]] - # hyps[i] is the recognition results, i.e., word IDs + # hyps[i] is the recognition results, i.e., word IDs or token IDs # for the i-th utterance with fast_beam_search_nbest_LG. - hyps: Union[List[List[int]], k2.RaggedTensor] = None - - -def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]: - tokens = [] - timestamps = [] - for i, v in enumerate(labels): - if v != 0: - tokens.append(v) - timestamps.append(i) - - return tokens, timestamps + hyps: Union[List[List[int]], k2.RaggedTensor] def get_texts_with_timestamp( best_paths: k2.Fsa, return_ragged: bool = False ) -> DecodingResults: - """Extract the texts (as word IDs) and timestamps from the best-path FSAs. + """Extract the texts (as word IDs) and timestamps (as frame indexes) + from the best-path FSAs. Args: best_paths: A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. @@ -292,11 +279,18 @@ def get_texts_with_timestamp( decoded. """ if isinstance(best_paths.aux_labels, k2.RaggedTensor): + all_aux_shape = ( + best_paths.arcs.shape() + .remove_axis(1) + .compose(best_paths.aux_labels.shape) + ) + all_aux_labels = k2.RaggedTensor( + all_aux_shape, best_paths.aux_labels.values + ) # remove 0's and -1's. aux_labels = best_paths.aux_labels.remove_values_leq(0) # TODO: change arcs.shape() to arcs.shape aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) - # remove the states and arcs axes. aux_shape = aux_shape.remove_axis(1) aux_shape = aux_shape.remove_axis(1) @@ -304,26 +298,26 @@ def get_texts_with_timestamp( else: # remove axis corresponding to states. aux_shape = best_paths.arcs.shape().remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) + all_aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) # remove 0's and -1's. 
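
A toy, self-contained illustration (made-up labels) of the frame-index extraction performed in the refactored get_texts_with_timestamp below: positions whose label is non-zero are the frames on which a token is emitted.

    # Per-arc aux labels along one best path; 0 means epsilon / no output.
    labels = [0, 0, 25, 0, 0, 0, 103, 0, 7, 0]
    timestamps = [i for i, v in enumerate(labels) if v > 0]
    # -> [2, 6, 8]: token 25 is emitted on frame 2, 103 on frame 6, 7 on frame 8.
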
- aux_labels = aux_labels.remove_values_leq(0) + aux_labels = all_aux_labels.remove_values_leq(0) assert aux_labels.num_axes == 2 - labels_shape = best_paths.arcs.shape().remove_axis(1) - labels_list = k2.RaggedTensor( - labels_shape, best_paths.labels.contiguous() - ).tolist() - - tokens = [] timestamps = [] - for labels in labels_list: - token, time = get_tokens_and_timestamps(labels[:-1]) - tokens.append(token) - timestamps.append(time) + if isinstance(best_paths.aux_labels, k2.RaggedTensor): + for p in range(all_aux_labels.dim0): + time = [] + for i, arc in enumerate(all_aux_labels[p].tolist()): + if len(arc) == 1 and arc[0] > 0: + time.append(i) + timestamps.append(time) + else: + for labels in all_aux_labels.tolist(): + time = [i for i, v in enumerate(labels) if v > 0] + timestamps.append(time) return DecodingResults( - tokens=tokens, timestamps=timestamps, hyps=aux_labels if return_ragged else aux_labels.tolist(), ) @@ -1399,8 +1393,8 @@ def parse_hyp_and_timestamp( hyps = [] timestamps = [] - N = len(res.tokens) - assert len(res.timestamps) == N + N = len(res.hyps) + assert len(res.timestamps) == N, (len(res.timestamps), N) use_word_table = False if ( decoding_method == "fast_beam_search_nbest_LG" @@ -1410,16 +1404,16 @@ def parse_hyp_and_timestamp( use_word_table = True for i in range(N): - tokens = sp.id_to_piece(res.tokens[i]) - if use_word_table: - words = [word_table[i] for i in res.hyps[i]] - else: - words = sp.decode_pieces(tokens).split() time = convert_timestamp( res.timestamps[i], subsampling_factor, frame_shift_ms ) - time = parse_timestamp(tokens, time) - assert len(time) == len(words), (tokens, words) + if use_word_table: + words = [word_table[i] for i in res.hyps[i]] + else: + tokens = sp.id_to_piece(res.hyps[i]) + words = sp.decode_pieces(tokens).split() + time = parse_timestamp(tokens, time) + assert len(time) == len(words), (len(time), len(words)) hyps.append(words) timestamps.append(time) From 2f43e4508b9eada64a9b89a4576935fb0b72694c Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 10 Nov 2022 22:28:04 +0800 Subject: [PATCH 024/120] fix mask errors when padding audios (#670) --- icefall/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index e83fccdde..c502cb4d8 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1017,11 +1017,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor: return concat(ragged, eos_id, direction="right") -def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """ Args: lengths: A 1-D tensor containing sentence lengths. + max_len: + The length of masks. Returns: Return a 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are @@ -1035,8 +1037,7 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: [False, False, False, False, False]]) """ assert lengths.ndim == 1, lengths.ndim - - max_len = lengths.max() + max_len = max(max_len, lengths.max()) n = lengths.size(0) expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) From e334e570d838cbe15188201f6bd47c009b9292be Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Nov 2022 07:57:58 +0800 Subject: [PATCH 025/120] Filter utterances with number_tokens > number_feature_frames. 
(#604) --- .../ASR/local/compute_fbank_librispeech.py | 31 +++- egs/librispeech/ASR/local/filter_cuts.py | 161 ++++++++++++++++++ .../ASR/lstm_transducer_stateless/train.py | 29 +++- .../ASR/lstm_transducer_stateless2/train.py | 38 ++++- .../ASR/lstm_transducer_stateless3/train.py | 29 +++- .../pruned_stateless_emformer_rnnt2/train.py | 29 +++- .../ASR/pruned_transducer_stateless/train.py | 29 +++- .../ASR/pruned_transducer_stateless2/train.py | 29 +++- .../ASR/pruned_transducer_stateless3/train.py | 38 ++++- .../ASR/pruned_transducer_stateless4/train.py | 29 +++- .../ASR/pruned_transducer_stateless5/train.py | 29 +++- .../ASR/pruned_transducer_stateless6/train.py | 29 +++- 12 files changed, 481 insertions(+), 19 deletions(-) create mode 100644 egs/librispeech/ASR/local/filter_cuts.py diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index f3e15e039..ce7d087f0 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -23,11 +23,15 @@ It looks for manifests in the directory data/manifests. The generated fbank features are saved in data/fbank. """ +import argparse import logging import os from pathlib import Path +from typing import Optional +import sentencepiece as spm import torch +from filter_cuts import filter_cuts from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached @@ -41,12 +45,29 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_librispeech(): +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + return parser.parse_args() + + +def compute_fbank_librispeech(bpe_model: Optional[str] = None): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + dataset_parts = ( "dev-clean", "dev-other", @@ -86,6 +107,9 @@ def compute_fbank_librispeech(): recordings=m["recordings"], supervisions=m["supervisions"], ) + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if "train" in partition: cut_set = ( cut_set @@ -109,5 +133,6 @@ if __name__ == "__main__": ) logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_librispeech() + args = get_args() + logging.info(vars(args)) + compute_fbank_librispeech(bpe_model=args.bpe_model) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py new file mode 100644 index 000000000..53dbb8211 --- /dev/null +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. + ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. " + f"{ratio:.3f}% data is removed." 
+ ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index fbb4e7224..d30fc260a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -987,7 +987,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index ac6bf7e04..5eaaf321f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -991,7 +991,10 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: +def filter_short_and_long_utterances( + cuts: CutSet, + sp: spm.SentencePieceProcessor, +) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -1001,7 +1004,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1104,7 +1134,7 @@ def run(rank, world_size, args): train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) + train_cuts = filter_short_and_long_utterances(train_cuts, sp) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours @@ -1121,7 +1151,7 @@ def run(rank, world_size, args): logging.info("Using the S subset of GigaSpeech (250 hours)") train_giga_cuts = gigaspeech.train_S_cuts() - train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index f2aa84625..60a5a2be7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -1007,7 +1007,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index dd23309b3..fed814f19 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -906,7 +906,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./emformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 193c5050c..399b11a29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -895,7 +895,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 7ce2ca779..1947834bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -961,7 +961,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 6cc34f18a..44e96644a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -952,7 +952,10 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: +def filter_short_and_long_utterances( + cuts: CutSet, + sp: spm.SentencePieceProcessor, +) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -962,7 +965,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1058,7 +1088,7 @@ def run(rank, world_size, args): train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) + train_cuts = filter_short_and_long_utterances(train_cuts, sp) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours @@ -1075,7 +1105,7 @@ def run(rank, world_size, args): logging.info("Using the S subset of GigaSpeech (250 hours)") train_giga_cuts = gigaspeech.train_S_cuts() - train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 57548270d..cf32e565b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -1011,7 +1011,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index b964cd05d..179d9372e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1043,7 +1043,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. 
" + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 25d1c4ca6..f717d85fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -1005,7 +1005,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) From 7e82f87126d3d380ed2bd0b280b5b18e6808d7d2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Nov 2022 18:11:19 +0800 Subject: [PATCH 026/120] Add Zipformer from Dan (#672) --- ...pruned-transducer-stateless7-2022-11-11.sh | 106 + .../run-librispeech-2022-11-11-stateless7.yml | 155 ++ egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 58 + .../ASR/pruned2_knowledge/__init__.py | 0 .../ASR/pruned2_knowledge/asr_datamodule.py | 428 ++++ .../ASR/pruned2_knowledge/beam_search.py | 766 +++++++ .../ASR/pruned2_knowledge/conformer.py | 1071 ++++++++++ .../ASR/pruned2_knowledge/decode.py | 547 +++++ .../ASR/pruned2_knowledge/decoder.py | 103 + .../ASR/pruned2_knowledge/decoder2.py | 238 +++ .../pruned2_knowledge/encoder_interface.py | 43 + .../ASR/pruned2_knowledge/export.py | 182 ++ .../ASR/pruned2_knowledge/joiner.py | 67 + .../ASR/pruned2_knowledge/model.py | 193 ++ .../ASR/pruned2_knowledge/optim.py | 331 +++ .../ASR/pruned2_knowledge/sampling.py | 332 +++ .../ASR/pruned2_knowledge/scaling.py | 707 +++++++ .../ASR/pruned2_knowledge/scaling_tmp.py | 628 ++++++ .../ASR/pruned2_knowledge/train.py | 997 +++++++++ .../beam_search.py | 4 +- .../pruned_transducer_stateless2/conformer.py | 2 + .../pruned_transducer_stateless7/__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../pruned_transducer_stateless7/decode.py | 854 ++++++++ .../pruned_transducer_stateless7/decoder.py | 104 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless7/export.py | 324 +++ .../jit_pretrained.py | 274 +++ .../pruned_transducer_stateless7/joiner.py | 67 + .../ASR/pruned_transducer_stateless7/model.py | 195 ++ .../ASR/pruned_transducer_stateless7/optim.py | 971 +++++++++ .../pretrained.py | 363 ++++ .../pruned_transducer_stateless7/scaling.py | 1161 ++++++++++ .../scaling_converter.py | 118 ++ .../test_model.py | 56 + .../ASR/pruned_transducer_stateless7/train.py | 1217 +++++++++++ 
.../pruned_transducer_stateless7/zipformer.py | 1858 +++++++++++++++++ icefall/checkpoint.py | 14 +- icefall/diagnostics.py | 92 +- icefall/hooks.py | 102 + 42 files changed, 14696 insertions(+), 36 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh create mode 100644 .github/workflows/run-librispeech-2022-11-11-stateless7.yml create mode 100644 egs/librispeech/ASR/pruned2_knowledge/__init__.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/beam_search.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/conformer.py create mode 100755 egs/librispeech/ASR/pruned2_knowledge/decode.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/decoder.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/decoder2.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned2_knowledge/export.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/joiner.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/model.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/optim.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/sampling.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/scaling.py create mode 100644 egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py create mode 100755 egs/librispeech/ASR/pruned2_knowledge/train.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/decode.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/model.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/train.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py create mode 100644 icefall/hooks.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh new file mode 100755 index 000000000..75861bbc7 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -0,0 +1,106 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + 
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless7/exp + done + + rm pruned_transducer_stateless7/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml new file mode 100644 index 000000000..3b98b500e --- /dev/null +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -0,0 +1,155 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-2022-11-11-stateless7 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_11_11_zipformer: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh + + - name: Display decoding results for 
librispeech pruned_transducer_stateless7 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7/exp + + cd pruned_transducer_stateless7 + echo "results for pruned_transducer_stateless7" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless7 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-2022-11-11 + path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/ diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 570d1ba1f..c366650bb 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -22,6 +22,7 @@ The following table lists the differences among them. | `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training | | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| +| `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | | `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 57dd9f230..43cd67c85 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### pruned_transducer_stateless7 (zipformer) + +See for more details. + +[pruned_transducer_stateless7](./pruned_transducer_stateless7) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. 
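A minimal sketch of loading the exported TorchScript checkpoint for deployment; the file layout is the one used by the CI script earlier in this patch, and the comment about the module's contents is an assumption based on the `export.py`/`jit_pretrained.py` scripts added here.

```python
import torch

# Path as laid out in the Hugging Face repo cloned by the CI script above.
model = torch.jit.load(
    "icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/cpu_jit.pt"
)
model.eval()

# The scripted module should expose the encoder, decoder and joiner; see
# pruned_transducer_stateless7/jit_pretrained.py (added by this commit) for
# the full feature-extraction and greedy-search pipeline built on top of it.
```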
+ +Number of model parameters: 70369391, i.e., 70.37 M + +| | test-clean | test-other | comment | +|----------------------|------------|-------------|----------------------------------------| +| greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 | +| fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,3,6,7" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --master-port 12535 +``` + +The decoding commands are: +```bash +for m in greedy_search fast_beam_search modified_beam_search ; do + for epoch in 30; do + for avg in 9; do + ./pruned_transducer_stateless7/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + + + ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter) #### [lstm_transducer_stateless3](./lstm_transducer_stateless3) diff --git a/egs/librispeech/ASR/pruned2_knowledge/__init__.py b/egs/librispeech/ASR/pruned2_knowledge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py new file mode 100644 index 000000000..8dd1459ca --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -0,0 +1,428 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
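As a rough sketch of how this class is meant to be driven (the methods used here are defined further down in this file; the argument values are purely illustrative):

```python
import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--full-libri", "0", "--max-duration", "200"])

librispeech = LibriSpeechAsrDataModule(args)

train_dl = librispeech.train_dataloaders(librispeech.train_clean_100_cuts())
valid_dl = librispeech.valid_dataloaders(librispeech.dev_clean_cuts())
test_dl = librispeech.test_dataloaders(librispeech.test_clean_cuts())
```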
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. 
" + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest(self.args.manifest_dir / 
"cuts_dev-clean.json.gz") + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py new file mode 100644 index 000000000..2e9bf3e0b --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -0,0 +1,766 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import k2 +import torch +from model import Transducer + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + Returns: + Return the decoded result. 
+ """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + + +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # Maximum symbols per utterance. 
+ max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits is (1, 1, 1, vocab_size) + + y = logits.argmax().item() + if y != blank_id: + hyp.append(y) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + return hyp + + +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + encoder_out = model.joiner.encoder_proj(encoder_out) + + # decoder_out: (batch_size, 1, decoder_out_dim) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + ans = [h[context_size:] for h in hyps] + return ans + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. 
+ """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. 
+ Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
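+
+    Note:
+      A sketch of the typical call pattern, mirroring how the ``decode.py``
+      scripts in these recipes use it (decoding one utterance ``i`` at a
+      time)::
+
+        hyp = beam_search(model=model, encoder_out=encoder_out[i:i+1], beam=4)
+        text = sp.decode(hyp)  # ``sp`` is the BPE model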
+ """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i == blank_id: + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + return ys diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py new file mode 100644 index 000000000..295a35204 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -0,0 +1,1071 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Optional, Tuple +from sampling import create_knowledge_base, KnowledgeBaseLookup + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) +from torch import Tensor, nn + +from icefall.utils import make_pad_mask + + +class Conformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension, also the output dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + layer_dropout (float): layer-dropout rate. + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + knowledge_M: int = 256, + knowledge_N: int = 2, + knowledge_D: int = 512, + knowledge_K: int = 16, + ) -> None: + super(Conformer, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + + self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, + knowledge_D) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + # Pass in a lambda that creates a new ConformerEncoderLayer with these + # args. Don't use deepcopy because we need the knowledge_base + # to be shared. + encoder_layer_fn = lambda: ConformerEncoderLayer( + self.knowledge_base, + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + knowledge_M, + knowledge_N, + knowledge_D, + knowledge_K + ) + self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. 
+ Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, d_model) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + mask = make_pad_mask(lengths) + + x = self.encoder( + x, pos_emb, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + knowledge_base: shared knowledge base parameter matrix, to be passed to constructors + of lookup modules + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + knowledge_M, knowledge_N, knowledge_D, knowledge_K: parameters for knowledge-base, + see docs for KnowlegeBaseLookup. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + knowledge_base: nn.Parameter, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + knowledge_M: int = 256, + knowledge_N: int = 2, + knowledge_D: int = 512, + knowledge_K: int = 16, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, + knowledge_D, knowledge_K, + d_model, + knowledge_base) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). 
+ src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout(self.conv_module(src)) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + # knowledge-base lookup + src = src + self.dropout(self.lookup(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__(self, encoder_layer_fn, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList( + [encoder_layer_fn() for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. 
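+
+    Examples (a minimal usage sketch)::
+
+        >>> pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
+        >>> x = torch.rand(8, 100, 256)
+        >>> x, pos_emb = pos_enc(x)  # pos_emb has shape (1, 2*100-1, 256)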
+
+    """
+
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct a RelPositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` is the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions
+        # otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded positional embedding (1, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+
+    Examples::
+
+        >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ) -> None:
+        super(RelPositionMultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+        # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. 
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. 
If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." 
+ ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, 
out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +if __name__ == "__main__": + feature_dim = 50 + c = Conformer(num_features=feature_dim, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py new file mode 100755 index 000000000..b4a9af55a --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned2_knowledge/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned2_knowledge/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./pruned2_knowledge/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned2_knowledge/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned2_knowledge/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned2_knowledge/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned2_knowledge/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned2_knowledge/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned2_knowledge/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
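+
+      For example, with greedy search the returned dict looks roughly like
+      ``{"greedy_search": [["HELLO", "WORLD"], ...]}``, i.e. one list of
+      words per utterance in the batch.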
+ """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
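+        # len(dl) raises TypeError when the sampler does not implement
+        # __len__ (e.g. when the cuts are loaded lazily); in that case the
+        # total number of batches is unknown and shown as "?".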
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for hyp_words, ref_text in zip(hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to
create model") + model = get_transducer_model(params) + + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py new file mode 100644 index 000000000..b6d94aaf1 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -0,0 +1,103 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scaling import ScaledConv1d, ScaledEmbedding + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. 
+ blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = ScaledEmbedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + padding_idx=blank_id, + ) + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + if context_size > 1: + self.conv = ScaledConv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim, + bias=False, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + embedding_out = self.embedding(y) + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + return embedding_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py new file mode 100644 index 000000000..db51fb1cd --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -0,0 +1,238 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + embedding_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + embedding_dim: + Dimension of the input embedding. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. 
+ """ + super().__init__() + self.embedding = ScaledEmbedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + padding_idx=blank_id, + ) + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + if context_size > 1: + self.conv = ScaledConv1d( + in_channels=embedding_dim, + out_channels=embedding_dim, + kernel_size=context_size, + padding=0, + groups=embedding_dim, + bias=False, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, embedding_dim). + """ + y = y.to(torch.int64) + embedding_out = self.embedding(y) + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + return embedding_out + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. 
+ + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale_speed = scale_speed + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = (self.scale * self.scale_speed).exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py b/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py new file mode 100644 index 000000000..257facce4 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/encoder_interface.py @@ -0,0 +1,43 @@ +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + + +class EncoderInterface(nn.Module): + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (batch_size, input_seq_len, num_features) + containing the input features. + x_lens: + A tensor of shape (batch_size,) containing the number of frames + in `x` before padding. + Returns: + Return a tuple containing two tensors: + - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) + containing unnormalized probabilities, i.e., the output of a + linear layer. + - encoder_out_lens, a tensor of shape (batch_size,) containing + the number of frames in `encoder_out` before padding. + """ + raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py new file mode 100755 index 000000000..96d1a30fb --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. 
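Model averaging, as referenced in the header above, is in essence a parameter-wise mean over the saved state_dicts; the script below delegates the real work to icefall's average_checkpoints. A rough standalone sketch of the idea (illustrative only, not the patch's implementation):

    from typing import Dict, List

    import torch


    def naive_average(
        state_dicts: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        # Parameter-wise mean over N checkpoints; tensors are cast to float
        # for simplicity.
        n = len(state_dicts)
        avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
        for sd in state_dicts[1:]:
            for k in avg:
                avg[k] += sd[k].float()
        return {k: v / n for k, v in avg.items()}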
+""" +Usage: +./pruned2_knowledge/export.py \ + --exp-dir ./pruned2_knowledge/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned2_knowledge/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned2_knowledge/decode.py \ + --exp-dir ./pruned2_knowledge/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned2_knowledge/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py new file mode 100644 index 000000000..35f75ed2a --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -0,0 +1,67 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from scaling import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) + self.output_linear = ScaledLinear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). 
+ project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py new file mode 100644 index 000000000..599bf2506 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -0,0 +1,193 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. 
+ prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py new file mode 100644 index 000000000..432bf8220 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -0,0 +1,331 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Union + +import torch +from torch.optim import Optimizer + + +class Eve(Optimizer): + r""" + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if not 0 <= weight_decay <= 0.1: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "AdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) + p.mul_(1 - (weight_decay * is_above_target_rms)) + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("initial_lr", group["lr"]) + + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. 
+ if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + print( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) + + E.g. suggest initial-lr = 0.003 (passed to optimizer). + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + + def get_lr(self): + factor = ( + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + ) ** -0.25 * ( + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 + ) + return [x * factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = Eve(m.parameters(), lr=0.003) + + scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + print("last lr = ", scheduler.get_last_lr()) + print("state dict = ", scheduler.state_dict()) + + +if __name__ == "__main__": + _test_eden() diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py new file mode 100644 index 000000000..7b05e2f00 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 + +# This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, +# its git history is there. 
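The Eden schedule in optim.py above reduces to the closed-form factor given in its docstring; a quick numerical check of that formula (standalone sketch, constants arbitrary):

    def eden_factor(
        batch: int, epoch: int, lr_batches: float = 5000, lr_epochs: float = 6
    ) -> float:
        # Mirrors Eden.get_lr(): product of a batch-wise and an epoch-wise decay.
        batch_part = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
        epoch_part = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
        return batch_part * epoch_part


    print(eden_factor(0, 0))       # 1.0 -- no decay at the very start
    print(eden_factor(5000, 0))    # 2 ** -0.25, roughly 0.84
    print(eden_factor(20000, 6))   # clearly smaller; both terms now decay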
+ +import timeit +import torch +from torch import Tensor +from torch import nn +from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd +from typing import Tuple, Optional +from scaling import ScaledLinear +import random +from torch_scheduled_sampling import sample_combined + +# The main exports of this file are the module KnowledgeBaseLookup and the +# function create_knowledge_base. + + + + + + +def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: + std = 0.1 + a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M ** N, D)) + nn.init.uniform_(ans, -a, a) + return ans + +def join_indexes(indexes: Tensor, M: int) -> Tensor: + """ + Combines N-tuples of indexes into single indexes that can be used for + lookup in the knowledge base. Args: + indexes: tensor of torch.int64 of shape (*, K, N), with elements in + {0..M-1} + M: the size of the original softmaxes, is upper bound on elements + in indexes + Returns: + joined_indexes: of shape (*, K), joined_indexes[...,k] equals + joined_indexes[...,0,k] + joined_indexes[...,1,k]*(M**1) ... + joined_indexes[...,1,k]*(M**(N-1))] + """ + N = indexes.shape[-1] + n_powers = M ** torch.arange(N, device=indexes.device) # [ 1, M, ..., M**(N-1) ] + return (indexes * n_powers).sum(dim=-1) + + +# Note, we don't use this, we +def weighted_matrix_lookup(weights: Tensor, + indexes: Tensor, + knowledge_base: Tensor) -> Tensor: + """ + Weighted combination of specified rows of a matrix. + weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. + indexes: Tensor of shape (*, K), with elements in [0..C-1] + knowledge_base: Tensor of shape (C-1, D), whose rows we'll be looking up + Returns: + tensor of shape (*, D), containing weighted sums of rows of + `knowledge_base` + """ + if True: + return WeightedMatrixLookupFunction.apply(weights, indexes, knowledge_base) + else: + # simpler but less memory-efficient implementation + lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) + D = knowledge_base.shape[-1] + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) + assert list(ans.shape) == list(weights.shape[:-2]) + [D] + return ans + + +class WeightedMatrixLookupFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + """ + Weighted combination of specified rows of a matrix. + weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. 
+ indexes: Tensor of shape (*, K), with elements in [0..C-1] + knowledge_base: Tensor of shape (C, D), whose rows we'll be looking up + Returns: + tensor of shape (*, D), containing weighted sums of rows of + `knowledge_base` + """ + if random.random() < 0.001: + print("dtype[1] = ", weights.dtype) + ctx.save_for_backward(weights.detach(), indexes.detach(), + knowledge_base.detach()) + with torch.no_grad(): + lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) + D = knowledge_base.shape[-1] + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) #(*, D) + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, Tensor]: + # ans_grad: (*, D) + weights, indexes, knowledge_base = ctx.saved_tensors + knowledge_base.requires_grad = True + dtype = ans_grad.dtype + ans_grad = ans_grad.to(weights.dtype) + assert weights.requires_grad == False + D = knowledge_base.shape[-1] + with torch.enable_grad(): + # we'll use torch's autograd to differentiate this operation, which + # is nontrivial [and anyway we need `lookup` to compute weight grad. + # We don't save `lookup` because it's large, that is the reason + # we override Torch autograd. + lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) + # forward pass: was: + ## ans = torch.matmul(weights, lookup) + ## ans: (*, 1, D) + ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) + weights_grad = torch.matmul(lookup, # (*, K, D) + ans_grad.unsqueeze(-1)) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + lookup.backward(gradient=lookup_grad) + return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) + + +class PenalizeNegentropyFunction(torch.autograd.Function): + """ + Function that does nothing in forward pass, but in backprop, it is as + if you had added: `- tot_entropy * alpha` to the loss function, where + tot_entropy is the the entropy of the average of the input distributions, + times the number of input distributions. (We multiply by this because + our overall loss function is proportional to the number of frames). + + This will tend to make the entropy want to become as large as possible, + making (-tot_entropy * alpha) as negative as possible. + + Args: + logprobs: Tensor of shape (*, num_classes), should be the result of + calling some_tensor.log_softmax(dim=-1) + Returns: + logprobs + """ + @staticmethod + def forward(ctx, logprobs: Tensor, alpha: float): + ctx.save_for_backward(logprobs.detach()) + ctx.alpha = alpha + return logprobs + + @staticmethod + def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: + logprobs, = ctx.saved_tensors + with torch.enable_grad(): + logprobs.requires_grad = True + # `negentropy` is the negative entropy of the average distribution. + # distributions. It will be <= 0. 
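+            # Everything below only shapes gradients: loss.backward() writes
+            # d(negentropy * scale)/d(logprobs) into logprobs.grad, and that
+            # gradient is added to the incoming logprobs_grad on return, which
+            # is equivalent to adding `- tot_entropy * alpha` to the training
+            # loss, as described in the class docstring.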
+ l = logprobs.reshape(-1, logprobs.shape[-1]) + scale = ctx.alpha * l.shape[0] + avg_dist = l.exp().mean(dim=0) + negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() + if random.random() < 0.0005: + negentropy_individual = (l * l.exp()).sum(dim=-1).mean() + print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + loss = negentropy * scale + loss.backward() + return logprobs_grad + logprobs.grad, None + + +class KnowledgeBaseLookup(nn.Module): + """ + Create knowledge-base lookup module. (The knowledge-base parameter, which is + large, is shared between these modules). + Args: + M: int, softmax size, e.g. in [32..128] + N: int, number of softmaxes, in [2..3] + D: int, embedding dimension in knowledge base, e.g. 256 + K: number of samples (affects speed/accuracy tradeoff), e.g. 16. + embedding_dim: the dimension to project from and to, e.g. the + d_model of the conformer. + """ + def __init__(self, M: int, N: int, D: int, + K: int, embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001): + super(KnowledgeBaseLookup, self).__init__() + self.knowledge_base = knowledge_base # shared! + self.in_proj = ScaledLinear(embedding_dim, M * N, + initial_scale=1.0) + # initial_scale = 4.0 because the knowlege_base activations are + # quite small -- if we use our optimizer they'll have stddev <= 0.1. + self.out_proj = ScaledLinear(D, embedding_dim, + initial_scale = 4.0) + self.M = M + self.N = N + self.K = K + self.negentropy_penalty = negentropy_penalty + + def forward(self, x: Tensor) -> Tensor: + """ + Forward function that does knowledge-base lookup. + Args: + x: input, of shape (*, E) where E is embedding_dim + as passed to constructor + y: output of knowledge-base lookup, of shape (*, E) + + # TODO: later we can try multiplying by a projection of x or something like that. + """ + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) + + _, indexes, weights = sample_combined(x, self.K, input_is_log=True) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) + return x + + +def _test_knowledge_base_lookup(): + K = 16 + N = 2 + M = 128 + D = 256 + E = 255 + + knowledge_base: nn.Parameter = create_knowledge_base(M, N, D) + m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base) + + B = 30 + T = 40 + x = torch.randn(B, T, E) + x.requires_grad = True + y = m(x) + assert y.shape == x.shape + y.sum().backward() # make sure backward doesn't crash.. 
+ print("y = ", y) + print("x.grad = ", x.grad) + print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) + + dtype = torch.float32 + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) + m = m.to(device).to(dtype) + + + start = timeit.default_timer() + +# Epoch 0, batch 0, loss 1.0109944343566895 +# Epoch 10, batch 0, loss 1.0146660804748535 +# Epoch 20, batch 0, loss 1.0119813680648804 +# Epoch 30, batch 0, loss 1.0105408430099487 +# Epoch 40, batch 0, loss 1.0077732801437378 +# Epoch 50, batch 0, loss 1.0050103664398193 +# Epoch 60, batch 0, loss 1.0033129453659058 +# Epoch 70, batch 0, loss 1.0014232397079468 +# Epoch 80, batch 0, loss 0.9977912306785583 +# Epoch 90, batch 0, loss 0.8274348974227905 +# Epoch 100, batch 0, loss 0.3368612825870514 +# Epoch 110, batch 0, loss 0.11323091387748718 +# Time taken: 17.591704960912466 + for epoch in range(150): + for n, (x,y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y)**2).mean() * 100.0 + if n % 10 == 0 and epoch % 10 == 0: + print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") + loss.backward() + optimizer.step() + optimizer.zero_grad() + + stop = timeit.default_timer() + print('Time taken: ', stop - start) + +def _test_knowledge_base_lookup_autocast(): + K = 16 + N = 2 + M = 128 + D = 256 + E = 255 + + knowledge_base: nn.Parameter = create_knowledge_base(M, N, D) + m = KnowledgeBaseLookup(M, N, D, K, E, knowledge_base) + + B = 30 + T = 40 + x = torch.randn(B, T, E) + x.requires_grad = True + y = m(x) + assert y.shape == x.shape + y.sum().backward() # make sure backward doesn't crash.. + print("y = ", y) + print("x.grad = ", x.grad) + print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) + + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) + m = m.to(device) + + scaler = GradScaler(enabled=True) + + start = timeit.default_timer() + + + for epoch in range(150): + for n, (x,y) in enumerate(train_pairs): + y_out = m(x) + with torch.cuda.amp.autocast(enabled=True): + loss = ((y_out - y)**2).mean() * 100.0 + if n % 10 == 0 and epoch % 10 == 0: + print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + stop = timeit.default_timer() + print('Time taken: ', stop - start) + + + +if __name__ == '__main__': + _test_knowledge_base_lookup() + _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py new file mode 100644 index 000000000..f726c2583 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -0,0 +1,707 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +from itertools import repeat +from typing import Optional, Tuple +from torch.cuda.amp import custom_fwd, custom_bwd + +import torch +import torch.nn as nn +from torch import Tensor + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + ) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + xgt0 = x > 0 + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) + + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs + + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) + ctx.max_factor = max_factor + ctx.sum_dims = sum_dims + return x + + @staticmethod + @custom_bwd + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors + dtype = x_grad.dtype + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) + + neg_delta_grad = x_grad.abs() * (factor + scale_factor) + return x_grad - neg_delta_grad, None, None, None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. 
+ eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() + ) ** -0.5 + return x * scales + + +class ScaledLinear(nn.Linear): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * self.weight_scale.exp() + bias = self.bias * self.bias_scale.exp() + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + initial_speed: this affects how fast the parameter will + learn near the start of training; you can set it to a + value less than one if you suspect that a module + is contributing to instability near the start of training. + Nnote: regardless of the use of this option, it's best to + use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. 
+ """ + + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): + super(ScaledLinear, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + return None if self.bias is None else self.bias * self.bias_scale.exp() + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) + + +class ScaledConv1d(nn.Conv1d): + # See docs for ScaledLinear + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): + super(ScaledConv1d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + return None if self.bias is None else self.bias * self.bias_scale.exp() + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +class ScaledConv2d(nn.Conv2d): + # See docs for ScaledLinear + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): + super(ScaledConv2d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3 ** 0.5) * std + 
nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + return None if self.bias is None else self.bias * self.bias_scale.exp() + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + + Args: + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + """ + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): + super(ActivationBalancer, self).__init__() + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + + def forward(self, x: Tensor) -> Tensor: + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). 
+                      = double_swish(x) * (1-s(x)) + s(x)
+     ... so we just need to remember s(x) but not x itself.
+    """
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor) -> Tensor:
+        x = x.detach()
+        s = torch.sigmoid(x - 1.0)
+        y = x * s
+        ctx.save_for_backward(s, y)
+        return y
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        s, y = ctx.saved_tensors
+        return (y * (1 - s) + s) * y_grad
+
+
+class DoubleSwish(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+        that we approximate closely with x * sigmoid(x-1).
+        """
+        return DoubleSwishFunction.apply(x)
+
+
+class ScaledEmbedding(nn.Module):
+    r"""This is a modified version of nn.Embedding that introduces a learnable scale
+    on the parameters.  Note: due to how we initialize it, it's best used with
+    schedulers like Noam that have a warmup period.
+
+    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
+
+    This module is often used to store word embeddings and retrieve them using indices.
+    The input to the module is a list of indices, and the output is the corresponding
+    word embeddings.
+
+    Args:
+        num_embeddings (int): size of the dictionary of embeddings
+        embedding_dim (int): the size of each embedding vector
+        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
+                                     (initialized to zeros) whenever it encounters the index.
+        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
+                                    is renormalized to have norm :attr:`max_norm`.
+        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
+        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
+                                                the words in the mini-batch. Default ``False``.
+        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
+                                 See Notes for more details regarding sparse gradients.
+
+        initial_speed (float, optional):  This affects how fast the parameter will
+           learn near the start of training; you can set it to a value less than
+           one if you suspect that a module is contributing to instability near
+           the start of training.  Note: regardless of the use of this option,
+           it's best to use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster.  Must be greater than 0.
+
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
+                         initialized from :math:`\mathcal{N}(0, 1)`
+
+    Shape:
+        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
+        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
+
+    .. note::
+        Keep in mind that only a limited number of optimizers support
+        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
+        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
+
+    .. note::
+        With :attr:`padding_idx` set, the embedding vector at
+        :attr:`padding_idx` is initialized to all zeros. However, note that this
+        vector can be modified afterwards, e.g., using a customized
+        initialization method, and thus changing the vector used to pad the
+        output. The gradient for this vector from :class:`~torch.nn.Embedding`
+        is always zero.
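+
+    .. note::
+        Unlike :class:`torch.nn.Embedding`, the lookup result here is multiplied
+        by a learnable scalar ``self.scale.exp()`` (see ``forward()``), so the
+        examples below, which construct ``nn.Embedding`` directly, illustrate the
+        expected shapes rather than the exact values this module returns.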
+ + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + + """ + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + initial_speed: float = 1.0, + ) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters(initial_speed) + + def reset_parameters(self, initial_speed: float = 1.0) -> None: + std = 0.1 / initial_speed + nn.init.normal_(self.weight, std=std) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) + else: + return F.embedding( + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}, scale={scale}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + return s.format(**self.__dict__) + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + 
print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) + + +if __name__ == "__main__": + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py new file mode 100644 index 000000000..6293e081a --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -0,0 +1,628 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Tuple, Optional + + + +def _activation_balancer_loss(mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10): + """ + Returns a loss-function for the ActivationBalancer module. This loss + function is not exposed to the user but is used internally, and eventually + its derivatives are scaled by some heuristic related to derivative magnitudes, + and added to the backpropped deriv. + + Args: + mean_pos: a Tensor of arbitrary dimension, probably something like (1, num_channels, 1, 1), + containing the mean of only the positive parts of the input features, i.e. + of x.relu(). + mean_neg: a Tensor of arbitrary dimension, probably something like (1, num_channels, 1, 1), + containing the mean of only the negative parts of the input features, i.e. + of (-x).relu(). 
+     min_positive: the minimum allowed value of mean_pos / (mean_pos + mean_neg) before we
+        start penalizing.
+     max_positive: the maximum allowed value of mean_pos / (mean_pos + mean_neg) before we
+        start penalizing.
+    """
+    loss_parts = []
+
+    x_mean = mean_pos - mean_neg
+    x_mean_abs = (mean_pos + mean_neg + eps).detach()
+    x_rel_mean = x_mean / x_mean_abs
+
+    if min_positive != 0.0:
+        # e.g. x_rel_mean_floor = -0.95 + 0.05 = -0.9
+        x_rel_mean_floor = (-(1-min_positive) + min_positive)
+        min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive))
+        # this part of the loss would be 1.0 * num_channels if all these constraints were
+        # 100% violated.
+        loss_parts.append(min_positive_loss)
+
+    if max_positive != 1.0:
+        # e.g. x_rel_mean_ceil = -0.05 + 0.95 = 0.9
+        x_rel_mean_ceil = - (1.0-max_positive) + max_positive
+        max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil))
+        # this part of the loss would be 1.0 * num_channels if all these constraints were
+        # 100% violated.
+        loss_parts.append(max_positive_loss)
+
+    if min_abs != 0.0:
+        min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs
+        # this part of the loss would be 1.0 * num_channels if all these constraints were
+        # 100% violated.
+        loss_parts.append(min_abs_loss)
+
+    if max_abs != 0.0:
+        max_abs_loss = (x_mean_abs / max_abs).log().relu().sum()
+        # this part of the loss would be [something logarithmic] * num_channels if all these
+        # constraints were 100% violated.
+        loss_parts.append(max_abs_loss)
+
+    return sum(loss_parts)
+
+
+class ActivationBalancerFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor,
+                channel_dim: int,
+                min_positive: float, # e.g. 0.05
+                max_positive: float, # e.g. 0.95
+                max_factor: float, # e.g. 0.01
+                min_abs: float, # e.g. 0.2
+                max_abs: float, # e.g. 100.0
+                ) -> Tensor:
+        if x.requires_grad:
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+            xgt0 = x > 0
+            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
+            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
+                       if min_positive != 0.0 else 0.0)
+            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
+                       if max_positive != 1.0 else 0.0)
+            factor = factor1 + factor2
+            if isinstance(factor, float):
+                factor = torch.zeros_like(proportion_positive)
+
+            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
+            below_threshold = (mean_abs < min_abs)
+            above_threshold = (mean_abs > max_abs)
+
+            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
+            ctx.max_factor = max_factor
+            ctx.sum_dims = sum_dims
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
+        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
+        dtype = x_grad.dtype
+        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
+                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
+
+        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
+        return x_grad - neg_delta_grad, None, None, None, None, None, None
+
+
+class BasicNorm(torch.nn.Module):
+    """
+    This is intended to be a simpler, and hopefully cheaper, replacement for
+    LayerNorm. 
The observation this is based on, is that Transformer-type
+    networks, especially with pre-norm, sometimes seem to set one of the
+    feature dimensions to a large constant value (e.g. 50), which "defeats"
+    the LayerNorm because the output magnitude is then not strongly dependent
+    on the other (useful) features.  Presumably the weight and bias of the
+    LayerNorm are required to allow it to do this.
+
+    So the idea is to introduce this large constant value as an explicit
+    parameter, that takes the role of the "eps" in LayerNorm, so the network
+    doesn't have to do this trick.  We make the "eps" learnable.
+
+    Args:
+       num_channels: the number of channels, e.g. 512.
+       channel_dim: the axis/dimension corresponding to the channel,
+         interpreted as an offset from the input's ndim if negative.
+         this is NOT the num_channels; it should typically be one of
+         {-2, -1, 0, 1, 2, 3}.
+       eps: the initial "epsilon" that we add as ballast in:
+          scale = ((input_vec**2).mean() + epsilon)**-0.5
+       Note: our epsilon is actually large, but we keep the name
+         to indicate the connection with conventional LayerNorm.
+       learn_eps: if true, we learn epsilon; if false, we keep it
+         at the initial value.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 channel_dim: int = -1,  # CAUTION: see documentation.
+                 eps: float = 0.25,
+                 learn_eps: bool = True) -> None:
+        super(BasicNorm, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        if learn_eps:
+            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+        else:
+            self.register_buffer('eps', torch.tensor(eps).log().detach())
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.shape[self.channel_dim] == self.num_channels
+        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
+                  self.eps.exp()) ** -0.5
+        return x * scales
+
+
+
+
+class ScaledLinear(nn.Linear):
+    """
+    A modified version of nn.Linear where the parameters are scaled before
+    use, via:
+         weight = self.weight * self.weight_scale.exp()
+         bias = self.bias * self.bias_scale.exp()
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts
+        e.g. in_features, out_features, bias=False.
+
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initialization of weight_scale and bias_scale).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+
+    Note: it uses the default initialization for the weight and bias,
+    inherited from nn.Linear.  For modules with small fan-in, this
+    may be larger than optimal. 
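+
+    Example (illustrative only; the feature sizes are arbitrary):
+
+        m = ScaledLinear(80, 256)
+        w = m.get_weight()   # equals m.weight * m.weight_scale.exp()
+        x = torch.randn(4, 80)
+        y = m(x)             # same as torch.nn.functional.linear(x, w, m.get_bias())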
+ """ + def __init__(self, *args, + initial_scale: float = 1.0, + **kwargs): + super(ScaledLinear, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + + self._reset_parameters() # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self): + std = 0.01 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + if self.bias is not None: + self.bias_scale += torch.tensor(scale / std).log() + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) + + +class ScaledConv1d(nn.Conv1d): + def __init__(self, *args, + initial_scale=1.0, **kwargs): + super(ScaledConv1d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + std = 0.01 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += torch.tensor(scale / std).log() + if self.bias is not None: + self.bias_scale += torch.tensor(scale / std).log() + + + def get_weight(self): + return self.weight * self.weight_scale.exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + + +class ScaledConv2d(nn.Conv2d): + def __init__(self, *args, initial_scale=1.0, **kwargs): + super(ScaledConv2d, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + std = 0.01 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += 
torch.tensor(scale / std).log()
+            if self.bias is not None:
+                self.bias_scale += torch.tensor(scale / std).log()
+
+
+    def get_weight(self):
+        return self.weight * self.weight_scale.exp()
+
+    def get_bias(self):
+        return (None if self.bias is None else
+                self.bias * self.bias_scale.exp())
+
+    def _conv_forward(self, input, weight):
+        F = torch.nn.functional
+        if self.padding_mode != 'zeros':
+            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.get_bias(), self.stride,
+                            _pair(0), self.dilation, self.groups)
+        return F.conv2d(input, weight, self.get_bias(), self.stride,
+                        self.padding, self.dilation, self.groups)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self._conv_forward(input, self.get_weight())
+
+
+
+
+class ActivationBalancer(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to try to encourage, for
+    each channel, that it is positive at least a proportion `threshold` of the
+    time.  It does this by multiplying negative derivative values by up to
+    (1+max_factor), and positive derivative values by up to (1-max_factor),
+    interpolated from 1 at the threshold to those extremal values when none
+    of the inputs are positive.
+
+
+    Args:
+           channel_dim: the dimension/axis corresponding to the channel, e.g.
+               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+           min_positive: the minimum, per channel, of the proportion of the time
+               that (x > 0), below which we start to modify the derivatives.
+           max_positive: the maximum, per channel, of the proportion of the time
+               that (x > 0), above which we start to modify the derivatives.
+           max_factor: the maximum factor by which we modify the derivatives for
+              either the sign constraint or the magnitude constraint;
+              e.g. with max_factor=0.02, the derivatives would be multiplied by
+              values in the range [0.98..1.02].
+           min_abs:  the minimum average-absolute-value per channel, which
+              we allow, before we start to modify the derivatives to prevent
+              this.
+           max_abs:  the maximum average-absolute-value per channel, which
+               we allow, before we start to modify the derivatives to prevent
+               this.
+    """
+    def __init__(self, channel_dim: int,
+                 min_positive: float = 0.05,
+                 max_positive: float = 0.95,
+                 max_factor: float = 0.01,
+                 min_abs: float = 0.2,
+                 max_abs: float = 100.0):
+        super(ActivationBalancer, self).__init__()
+        self.channel_dim = channel_dim
+        self.min_positive = min_positive
+        self.max_positive = max_positive
+        self.max_factor = max_factor
+        self.min_abs = min_abs
+        self.max_abs = max_abs
+
+    def forward(self, x: Tensor) -> Tensor:
+        return ActivationBalancerFunction.apply(x, self.channel_dim,
+                                                self.min_positive, self.max_positive,
+                                                self.max_factor, self.min_abs,
+                                                self.max_abs)
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+    """
+    double_swish(x) = x * torch.sigmoid(x-1)
+    This is a definition, originally motivated by its close numerical
+    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+    Memory-efficient derivative computation:
+     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+     double_swish'(x) = d/dx double_swish(x) =  x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+     Now, s'(x) = s(x) * (1-s(x)).
+     double_swish'(x) =  x * s'(x) + s(x).
+                      =  x * s(x) * (1-s(x)) + s(x).
+                      = double_swish(x) * (1-s(x)) + s(x)
+     ... so we just need to remember s(x) but not x itself.
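+
+    A quick numerical check of this derivative (an illustrative sketch that
+    mirrors _test_double_swish_deriv() at the bottom of this file):
+
+        x = torch.randn(6, 5, dtype=torch.double, requires_grad=True)
+        torch.autograd.gradcheck(DoubleSwishFunction.apply, x)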
+ """ + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + s, y = ctx.saved_tensors + return (y * (1-s) + s) * y_grad + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + return DoubleSwishFunction.apply(x) + + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. 
+ + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + std = 0.01 + nn.init.normal_(self.weight, std=std) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + +def _test_activation_balancer_sign(): + channel_dim = 0 + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + +def 
_test_activation_balancer_magnitude(): + channel_dim = 0 + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) + + +if __name__ == '__main__': + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py new file mode 100755 index 000000000..2f6840166 --- /dev/null +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -0,0 +1,997 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned2_knowledge/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned2_knowledge/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned2_knowledge/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --use_fp16 1 \ + --exp-dir pruned2_knowledge/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_stateless2/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned2_knowledge/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. 
+        We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols (context) "
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) "
+        "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just addition); this simple loss is also used for "
+        "training (as a regularization term). We will scale the simple loss "
+        "with this parameter before adding it to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 18, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
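+
+    Example (illustrative; this mirrors how run() calls it below):
+
+        checkpoints = load_checkpoint_if_available(params=params, model=model)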
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
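+
+    Returns:
+      Return a tuple of two elements: the loss (a scalar tensor computed with
+      reduction="sum" over the utterances in the batch) and a MetricsTracker
+      that records the number of frames and the detached loss values for
+      logging.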
+ """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+    model.device = device
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs): + scheduler.step_epoch(epoch) + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=0.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index a3fa6cc7c..da88e257b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -537,7 +537,7 @@ def greedy_search( device = next(model.parameters()).device decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 + [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -646,7 +646,7 @@ def greedy_search_batch( assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) - hyps = [[blank_id] * context_size for _ in range(N)] + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] # timestamp[n][i] is the frame index after subsampling # on which hyp[n][i] is decoded diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index b04a74a19..bc273d33b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -1621,6 +1621,8 @@ class Conv2dSubsampling(nn.Module): if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) feature_dim = 50 c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..06c5863f1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -0,0 +1,854 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
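        For example (made-up words, a two-utterance batch decoded with greedy
        search), the returned dict would look like:

            {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}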
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
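The per-word error statistics mentioned in the comment above come from icefall's write_error_stats; as a rough, self-contained illustration of how the headline WER it reports is computed (edit distance divided by reference length), here is a simplified sketch, not the actual icefall implementation:

    from typing import List

    def simple_wer(ref: List[str], hyp: List[str]) -> float:
        # WER = (substitutions + insertions + deletions) / len(ref),
        # computed with a plain dynamic-programming edit distance.
        d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
        for i in range(len(ref) + 1):
            d[i][0] = i
        for j in range(len(hyp) + 1):
            d[0][j] = j
        for i in range(1, len(ref) + 1):
            for j in range(1, len(hyp) + 1):
                cost = 0 if ref[i - 1] == hyp[j - 1] else 1
                d[i][j] = min(
                    d[i - 1][j] + 1,          # deletion
                    d[i][j - 1] + 1,          # insertion
                    d[i - 1][j - 1] + cost,   # substitution or match
                )
        return d[len(ref)][len(hyp)] / max(len(ref), 1)

    # simple_wer("a b c".split(), "a x c d".split()) -> 2/3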
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
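The epoch/iteration averaging branches above (average_checkpoints and average_checkpoints_with_averaged_model) boil down to a parameter-wise mean over saved model states; a simplified standalone sketch, assuming each checkpoint stores its state dict under the "model" key as in this recipe (the real icefall helpers differ in details, so treat this as an illustration only):

    import torch

    def average_state_dicts(filenames, device="cpu"):
        # Parameter-wise mean over the "model" state dicts of several
        # checkpoint files.
        assert len(filenames) > 0
        avg = None
        for n, f in enumerate(filenames, start=1):
            sd = torch.load(f, map_location=device)["model"]
            if avg is None:
                avg = {k: v.clone().float() for k, v in sd.items()}
            else:
                for k in avg:
                    avg[k] += sd[k].float()
        return {k: v / n for k, v in avg.items()}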
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py new file mode 100644 index 000000000..712dc8ce1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -0,0 +1,104 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + padding_idx=blank_id, + ) + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim//4, # group size == 4 + bias=False, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). 
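        As a concrete shape example (made-up sizes), with context_size=2 and
        decoder_dim=512:

        - training (need_pad=True): y of shape (8, 10) gives an output of
          shape (8, 10, 512)
        - inference (need_pad=False): y holds only the last 2 tokens, i.e.
          shape (8, 2), and the output has shape (8, 1, 512)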
+ """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + return embedding_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py new file mode 100755 index 000000000..5744ea3ea --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
+ +To use the generated file with `pruned_transducer_stateless7/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7/decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py new file mode 100755 index 000000000..81b0deba3 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
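        Example (hypothetical file names):

            waves = read_sound_files(["foo.wav", "bar.wav"])
            # -> two 1-D float32 tensors containing 16 kHz audio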
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + 
features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py new file mode 100644 index 000000000..7d8de5afe --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -0,0 +1,67 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) + self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim + assert encoder_out.ndim in (2, 4) + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py new file mode 100644 index 000000000..53cde6c6f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -0,0 +1,195 @@ +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import k2 +import torch +import torch.nn as nn +import random +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos +from scaling import penalize_abs_values_gt + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = nn.Linear( + encoder_dim, vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + #if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + #if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py new file mode 100644 index 000000000..bb8b0a0e3 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -0,0 +1,971 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import List, Optional, Union, Tuple, List +from lhotse.utils import fix_random_seed +import torch +from scaling import ActivationBalancer +import random +from torch import Tensor +from torch.optim import Optimizer +import logging +import contextlib + + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + + + @contextlib.contextmanager + def batched_params(self, param_group): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.groups. + """ + batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + + for p in param_group: + key = (str(p.dtype), *p.shape) + batches[key].append(p) + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + batches = [ batches[key] for key in sorted(batches.keys()) ] + # pairs will contain pairs of (stacked_param, state), one for each batch + # in `batches`. + pairs = [] + + for batch in batches: + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + pairs.append((p_stacked, state)) + + yield pairs # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state), batch) in zip(pairs, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. 
much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + super(ScaledAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + for group in self.param_groups: + + with self.batched_params(group["params"]) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. 
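+                    # Here `p` is a stacked "fake" parameter whose dim 0 indexes the
+                    # real same-shaped parameters; batched_params() already filled in
+                    # its .grad, using zeros for any parameter that had no grad.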
+ grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + + return loss + + def _init_state(self, + group: dict, + p: Tensor, + state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {'device':p.device, 'dtype':p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + numel = p.numel() + + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, + **kwargs) + + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + def _get_clipping_scale(self, + group: dict, + pairs: List[Tuple[Tensor, dict]]) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad + (1st dim is batch dim) and state is the state-dict where optimization parameters are kept. + """ + assert len(pairs) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state) = pairs[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state) in pairs: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"])**2).sum() + + tot_norm = tot_sumsq.sqrt() + if not "model_norms" in first_state: + first_state["model_norms"] = torch.zeros(clipping_update_period, + device=p.device) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. 
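+            # (The stats are the quartiles of the recent gradient norms; the clipping
+            # threshold is set to clipping_scale times the median of those norms.)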
+ # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + quartiles = [] + for n in range(0, 5): + index = min(clipping_update_period - 1, + (clipping_update_period // 4) * n) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state else 0.0) + first_state["num_clipped"] = 0 + quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) + logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except: + logging.info("Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?") + return 1.0 + ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + return ans + + + def _step_one_batch(self, + group: dict, + p: Tensor, + state: dict, + clipping_scale: float): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt()) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + + def _size_update(self, + group: dict, + scale_grads: Tensor, + p: Tensor, + state: dict) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. 
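+              (each entry is the per-tensor sum of p * grad from one recent step,
+              i.e. the gradient of the loss w.r.t. a log-scale on that parameter.)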
+ p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2 ** size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` + alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr ** size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + + is_too_small = (param_rms < param_min_rms) + is_too_large = (param_rms > param_max_rms) + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + # when it gets too large, stop it from getting any larger. + scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1-beta1)) + + + def _step(self, + group: dict, + p: Tensor, + state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=(1-beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + + def _step_scalar(self, + group: dict, + p: Tensor, + state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=1-beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. 
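+        # The rest is a plain Adam-style update using the scalar learning rate
+        # (group lr times scalar_lr_scale): `delta` accumulates
+        # -lr * (1 - beta1) * grad / (sqrt(exp_avg_sq / bias_correction2) + eps),
+        # and the parameter is clamped to [-scalar_max, scalar_max] before `delta`
+        # is added.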
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr*(1-beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [ + group["base_lr"] for group in optimizer.param_groups + ] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + def get_lr(self): + factor = ( + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + ) ** -0.25 * ( + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 + ) + warmup_factor = (1.0 if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches)) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if not 0 <= weight_decay <= 0.1: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "AdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg/denom) * step_size + logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") + + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + from scaling import ScaledLinear + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + #device = torch.device('cuda') + device = torch.device('cpu') + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. 
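+    # (Each of the E dims gets a random log-normal scale, which is multiplied into
+    # the synthetic inputs and regression targets generated below.)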
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential(Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + + if iter == 0: optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + #if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + #if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + + for n, (x,y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y)**2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + #diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + #logging.info("state dict = ", scheduler.state_dict()) + #logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + logging.info(s) + import sys + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py new file mode 100755 index 000000000..7fe1e681a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by +./pruned_transducer_stateless7/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for 
filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py new file mode 100644 index 000000000..50cedba56 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -0,0 +1,1161 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +from itertools import repeat +from typing import Optional, Tuple, Union +from functools import reduce +import logging + +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Embedding as ScaledEmbedding + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = (x > 0) + if sign_factor is None: + ctx.save_for_backward(xgt0, scale_factor) + else: + ctx.save_for_backward(xgt0, scale_factor, sign_factor) + return x + + + @staticmethod + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: + if len(ctx.saved_tensors) == 3: + xgt0, scale_factor, sign_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + sign_factor = sign_factor.unsqueeze(-1) + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + else: + xgt0, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return x_grad - neg_delta_grad, None, None, None, + +def _compute_scale_factor(x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) + + if min_abs == 0.0: + below_threshold = 0.0 + else: + # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if + # x_abs)_mean , min_abs. 
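+        # (i.e. the factor grows linearly as x_abs_mean falls below min_abs and is
+        # capped at max_factor.)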
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + + return below_threshold - above_threshold + +def _compute_sign_factor(x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(torch.float32), + dim=sum_dims) + if min_positive == 0.0: + factor1 = 0.0 + else: + # 0 if proportion_positive >= min_positive, else can be + # as large as max_factor. + factor1 = ((min_positive - proportion_positive) * + (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + + if max_positive == 1.0: + factor2 = 0.0 + else: + # 0 if self.proportion_positive <= max_positive, else can be + # as large as -max_factor. + factor2 = ((proportion_positive - max_positive) * + (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + sign_factor = factor1 - factor2 + # require min_positive != 0 or max_positive != 1: + assert not isinstance(sign_factor, float) + return sign_factor + + + +class ActivationScaleBalancerFunction(torch.autograd.Function): + """ + This object is used in class ActivationBalancer when the user specified + min_positive=0, max_positive=1, so there are no constraints on the signs + of the activations and only the absolute value has a constraint. + """ + @staticmethod + def forward( + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = (x > 0) + ctx.save_for_backward(xgt0, sign_factor, scale_factor) + return x + + + @staticmethod + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: + xgt0, sign_factor, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + sign_factor = sign_factor.unsqueeze(-1) + scale_factor = scale_factor.unsqueeze(-1) + + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return x_grad - neg_delta_grad, None, None, None, + + +class RandomClampFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float) -> Tensor: + x_clamped = torch.clamp(x, min=min, max=max) + mask = torch.rand_like(x) < prob + ans = torch.where(mask, x_clamped, x) + if x.requires_grad: + ctx.save_for_backward(ans == x) + ctx.reflect = reflect + if reflect != 0.0: + ans = ans * (1.0 + reflect) - (x * reflect) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + is_same, = ctx.saved_tensors + x_grad = ans_grad * is_same.to(ans_grad.dtype) + reflect = ctx.reflect + if reflect != 0.0: + x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) + return x_grad, None, None, None, None + +def random_clamp(x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0): + return RandomClampFunction.apply(x, min, max, prob, reflect) + + +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. 
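+    Elements whose absolute value is below `min_abs` are stochastically rounded
+    to either 0 or +-min_abs in a way that preserves their expectation, which
+    avoids systematically losing very small values to float16 rounding.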
+ """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = (x_abs < min_abs) + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class RandomGradFunction(torch.autograd.Function): + """ + Does nothing in forward pass; in backward pass, gets rid of very small grads using + randomized approach that preserves expectations (intended to reduce roundoff). + """ + @staticmethod + def forward(ctx, x: Tensor, min_abs: float) -> Tensor: + ctx.min_abs = min_abs + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: + if ans_grad.dtype == torch.float16: + return random_cast_to_half(ans_grad.to(torch.float32), + min_abs=ctx.min_abs), None + else: + return ans_grad, None + +class RandomGrad(torch.nn.Module): + """ + Gets rid of very small gradients using an expectation-preserving method, intended to increase + accuracy of training when using amp (automatic mixed precision) + """ + def __init__(self, + min_abs: float = 5.0e-06): + super(RandomGrad, self).__init__() + self.min_abs = min_abs + + def forward(self, + x: Tensor): + if torch.jit.is_scripting() or not self.training: + return x + else: + return RandomGradFunction.apply(x, self.min_abs) + + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + ans, = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + + +def softmax(x: Tensor, + dim: int): + if torch.jit.is_scripting(): + return x.softmax(dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) + return x + + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. 
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_min: float + eps_max: float + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_min: float = -3.0, + eps_max: float = 3.0, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.eps_min = eps_min + self.eps_max = eps_max + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + eps = self.eps + if self.training and random.random() < 0.25: + # with probability 0.25, in training mode, clamp eps between the min + # and max; this will encourage it to learn parameters within the + # allowed range by making parameters that are outside the allowed + # range noisy. + + # gradients to allow the parameter to get back into the allowed + # region if it happens to exit it. + eps = eps.clamp(min=self.eps_min, max=self.eps_max) + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + ) ** -0.5 + return x * scales + + + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. 
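+
+    Illustrative example (dims and scale are arbitrary):
+        proj = ScaledLinear(512, 256, bias=True, initial_scale=0.25)
+        # the weight starts 4x smaller than the nn.Linear default, and the bias
+        # is re-initialized to uniform(-0.1 * initial_scale, 0.1 * initial_scale).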
+ """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + min_prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. 
+ """ + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): + super(ActivationBalancer, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor + + # count measures how many times the forward() function has been called. + # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. + self.cpu_count = 0 + self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) + + + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or not x.requires_grad: + return _no_op(x) + + count = self.cpu_count + self.cpu_count += 1 + + if random.random() < 0.01: + # Occasionally sync self.cpu_count with self.count. + # count affects the decay of 'prob'. don't do this on every iter, + # because syncing with the GPU is slow. + self.cpu_count = max(self.cpu_count, self.count.item()) + self.count.fill_(self.cpu_count) + + # the prob of doing some work exponentially decreases from 0.5 till it hits + # a floor at min_prob (==0.1, by default) + prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + + if random.random() < prob: + sign_gain_factor = 0.5 + if self.min_positive != 0.0 or self.max_positive != 1.0: + sign_factor = _compute_sign_factor(x, self.channel_dim, + self.min_positive, self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor) + else: + sign_factor = None + + + scale_factor = _compute_scale_factor(x, self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor) + return ActivationBalancerFunction.apply( + x, scale_factor, sign_factor, self.channel_dim, + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). 
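+    # Equivalently: d(aux_loss)/dx is penalty * sign(x) wherever |x| > limit and 0
+    # elsewhere, which matches the gradient of penalty * (x.abs() - limit).relu().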
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, ::dim+1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, + num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float) -> Tensor: + ctx.save_for_backward(x) + ctx.num_groups = num_groups + ctx.whitening_limit = whitening_limit + ctx.grad_scale = grad_scale + return x + + @staticmethod + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, ctx.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. 
limit={ctx.whitening_limit}") + + (metric - ctx.whitening_limit).relu().backward() + penalty_grad = x_detached.grad + scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float,float]], + grad_scale: float): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert whitening_limit >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + if isinstance(prob, float): + assert 0 < prob <= 1 + self.prob = prob + else: + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob < self.max_prob <= 1 + self.prob = self.max_prob + + self.grad_scale = grad_scale + + def forward(self, + x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: + return _no_op(x) + else: + if hasattr(self, 'min_prob') and random.random() < 0.25: + # occasionally switch between min_prob and max_prob, based on whether + # we are above or below the threshold. + if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + # there would be a change to the grad. + self.prob = self.max_prob + else: + self.prob = self.min_prob + + return WhiteningPenaltyFunction.apply(x, + self.num_groups, + self.whitening_limit, + self.grad_scale) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor): + ctx.y_shape = y.shape + return x + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device) +def with_loss(x, y): + if torch.jit.is_scripting(): + return x + # returns x but adds y.sum() to the loss function. 
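+    # (In WithLoss.backward the gradient w.r.t. y is all ones, so this has the
+    # same effect as adding y.sum() to the final loss; see the call
+    # `x = with_loss(x, aux_loss)` in penalize_abs_values_gt() above.)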
+ return WithLoss.apply(x, y) + + +def _no_op(x: Tensor) -> Tensor: + if (torch.jit.is_scripting()): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + +class MaxEig(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to discourage + that any given direction in activation space accounts for more than + a specified proportion of the covariance (e.g. 0.2). + + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + max_var_per_eig: the maximum proportion of the variance of the + features/channels, after mean subtraction, that can come from + any given eigenvalue. + min_prob: the minimum probability with which we apply this during any invocation + of forward(), assuming last time we applied the constraint it was + not active; supplied for speed. + scale: determines the scale with which we modify the gradients, relative + to the existing / unmodified gradients + """ + def __init__( + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, + ): + super(MaxEig, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.scale = scale + assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels + self.max_var_per_eig = max_var_per_eig + + # we figure out the dominant direction using the power method: starting with + # a random vector, keep multiplying by the covariance and renormalizing. + with torch.no_grad(): + # arbitrary.. would use randn() but want to leave the rest of the model's + # random parameters unchanged for comparison + direction = torch.arange(num_channels).to(torch.float) + direction = direction / direction.norm() + self.register_buffer('max_eig_direction', direction) + + self.min_prob = min_prob + # cur_prob is the current probability we'll use to apply the ActivationBalancer. + # We'll regress this towards prob, each tiem we try to apply it and it is not + # active. + self.cur_prob = 1.0 + + + + def forward(self, x: Tensor) -> Tensor: + if (torch.jit.is_scripting() or + self.max_var_per_eig <= 0 or + random.random() > self.cur_prob): + return _no_op(x) + + with torch.cuda.amp.autocast(enabled=False): + eps = 1.0e-20 + orig_x = x + x = x.to(torch.float32) + with torch.no_grad(): + x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) + x = x - x.mean(dim=0) + new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + + # ensure new direction is nonzero even if x == 0, by including `direction`. + self._set_direction(0.1 * self.max_eig_direction + new_direction) + + if random.random() < 0.01 or __name__ == "__main__": + logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + + if variance_proportion >= self.max_var_per_eig: + # The constraint is active. 
Note, we should quite rarely + # reach here, only near the beginning of training if we are + # starting to diverge, should this constraint be active. + cur_prob = self.cur_prob + self.cur_prob = 1.0 # next time, do the update with probability 1.0. + return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, + self.channel_dim, self.scale) + else: + # let self.cur_prob exponentially approach self.min_prob, as + # long as the constraint is inactive. + self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob + return orig_x + + + def _set_direction(self, + direction: Tensor): + """ + Sets self.max_eig_direction to a normalized version of `direction` + """ + direction = direction.detach() + direction = direction / direction.norm() + direction_sum = direction.sum().item() + if direction_sum - direction_sum == 0: # no inf/nan + self.max_eig_direction[:] = direction + else: + logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + + + def _find_direction_coeffs(self, + x: Tensor, + prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ + (num_frames, num_channels) = x.shape + assert num_channels > 1 and num_frames > 1 + assert prev_direction.shape == (num_channels,) + # `coeffs` are the coefficients of `prev_direction` in x. + # actually represent the coeffs up to a constant positive factor. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + return cur_direction, coeffs + + + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = (y * (1 - s) + s) + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. 
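+            # For example, with the scaling below a derivative value of 0.5 maps to
+            # (0.5 + 0.043637) * 255 / 1.243637 ~= 111.5, so it is stored as 111 or
+            # 112 and dequantizes back to roughly 0.498 or 0.503 in backward().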
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.043637 + ceil = 1.2 + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + + +def _test_max_eig(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = MaxEig(num_channels, + 1, # channel_dim + 0.5, # max_var_per_eig + scale=0.1) # grad_scale + + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad, atol=1.0e-02) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + min_prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", 
x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = ((1.2-(-0.043637))/255.0) + torch.autograd.gradcheck(m, x, atol=tol) + + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:,0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:,0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() + _test_whiten() + _test_max_eig() + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py new file mode 100644 index 000000000..8d357b15f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1,118 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file replaces various modules in a model. +Specifically, ActivationBalancer is replaced with an identity operator; +Whiten is also replaced with an identity operator; +BasicNorm is replaced by a module with `exp` removed. +""" + +import copy +from typing import List + +import torch +import torch.nn as nn +from scaling import ( + ActivationBalancer, + BasicNorm, + Whiten, +) + + +class NonScaledNorm(nn.Module): + """See BasicNorm for doc""" + + def __init__( + self, + num_channels: int, + eps_exp: float, + channel_dim: int = -1, # CAUTION: see documentation. 
+ ): + super().__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.eps_exp = eps_exp + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not torch.jit.is_tracing(): + assert x.shape[self.channel_dim] == self.num_channels + scales = ( + torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp + ).pow(-0.5) + return x * scales + + +def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: + assert isinstance(basic_norm, BasicNorm), type(BasicNorm) + norm = NonScaledNorm( + num_channels=basic_norm.num_channels, + eps_exp=basic_norm.eps.data.exp().item(), + channel_dim=basic_norm.channel_dim, + ) + return norm + + +# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa +# get_submodule was added to nn.Module at v1.9.0 +def get_submodule(model, target): + if target == "": + return model + atoms: List[str] = target.split(".") + mod: torch.nn.Module = model + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) + mod = getattr(mod, item) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not " "an nn.Module") + return mod + + +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, +): + """ + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + Return: + Return a model without scaled layers. + """ + if not inplace: + model = copy.deepcopy(model) + + d = {} + for name, m in model.named_modules(): + if isinstance(m, BasicNorm): + d[name] = convert_basic_norm(m) + elif isinstance(m, (ActivationBalancer, Whiten)): + d[name] = nn.Identity() + + for k, v in d.items(): + if "." in k: + parent, child = k.rsplit(".", maxsplit=1) + setattr(get_submodule(model, parent), child, v) + else: + setattr(model, k, v) + + return model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py new file mode 100755 index 000000000..db7fb7b3e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless7/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = "2,4,3,2,4"
+    # params.feedforward_dims = "1024,1024,1536,1536,1024"
+    params.feedforward_dims = "1024,1024,2048,2048,1024"
+    params.nhead = "8,8,8,8,8"
+    params.encoder_dims = "384,384,384,384,384"
+    params.attention_dims = "192,192,192,192,192"
+    params.encoder_unmasked_dims = "256,256,256,256,256"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    params.decoder_dim = 512
+    params.joiner_dim = 512
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+def main():
+    test_model_1()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..8927be227
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1217 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.  (authors: Fangjun Kuang,
+#                                                 Wei Kang,
+#                                                 Mingshuang Luo,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from zipformer import Zipformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.hooks import register_inf_check_hooks +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def set_batch_count( + model: Union[nn.Module, DDP], batch_count: float +) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, 'batch_count'): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""" + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse." 
+ ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.05, + help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(','))) + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(',')[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(',')[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
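+
+    A minimal usage sketch (illustrative; it mirrors how `run()` below uses
+    this function):
+
+        checkpoints = load_checkpoint_if_available(
+            params=params, model=model, model_avg=model_avg
+        )
+        # later, once the optimizer has been created:
+        if checkpoints and "optimizer" in checkpoints:
+            optimizer.load_state_dict(checkpoints["optimizer"])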
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
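+
+    As a rough worked example of the warmup weighting applied below (using the
+    default simple_loss_scale=0.5 and warm_step=2000): at batch_idx_train=1000
+    the simple loss is weighted by 1.0 - (1000 / 2000) * (1.0 - 0.5) = 0.75 and
+    the pruned loss by 0.1 + 0.9 * (1000 / 2000) = 0.55; once batch_idx_train
+    reaches warm_step the weights settle at 0.5 and 1.0 respectively.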
+ """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = ( + simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], + find_unused_parameters=True) + + optimizer = ScaledAdam(model.parameters(), + lr=params.base_lr, + clipping_scale=2.0) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, + init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py new file mode 100644 index 000000000..c14066d38 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -0,0 +1,1858 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +import itertools +from typing import List, Optional, Tuple, Union +import logging +import torch +import random +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + MaxEig, + DoubleSwish, + ScaledConv1d, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + Identity, + _diag, + random_clamp, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + +from icefall.utils import make_pad_mask +from icefall.dist import get_rank + + +class Zipformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + d_model: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + nhead (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. 
+        warmup_batches (float): number of batches to warm up over
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        output_downsampling_factor: int = 2,
+        encoder_dims: Tuple[int] = (384, 384),
+        attention_dim: Tuple[int] = (256, 256),
+        encoder_unmasked_dims: Tuple[int] = (256, 256),
+        zipformer_downsampling_factors: Tuple[int] = (2, 4),
+        nhead: Tuple[int] = (8, 8),
+        feedforward_dim: Tuple[int] = (1536, 2048),
+        num_encoder_layers: Tuple[int] = (12, 12),
+        dropout: float = 0.1,
+        cnn_module_kernels: Tuple[int] = (31, 31),
+        pos_dim: int = 4,
+        warmup_batches: float = 4000.0,
+    ) -> None:
+        super(Zipformer, self).__init__()
+
+        self.num_features = num_features
+        self.encoder_unmasked_dims = encoder_unmasked_dims
+        assert 0 < encoder_dims[0] <= encoder_dims[1]
+        self.encoder_dims = encoder_dims
+        self.encoder_unmasked_dims = encoder_unmasked_dims
+        self.zipformer_downsampling_factors = zipformer_downsampling_factors
+        self.output_downsampling_factor = output_downsampling_factor
+
+        # will be written to, see set_batch_count()
+        self.batch_count = 0
+        self.warmup_end = warmup_batches
+
+        for u,d in zip(encoder_unmasked_dims, encoder_dims):
+            assert u <= d, (u, d)
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, (T - 7)//2, encoder_dims).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> (T - 7)//2
+        #   (2) embedding: num_features -> encoder_dims
+        self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
+                                               dropout=dropout)
+
+
+        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        encoders = []
+
+        num_encoders = len(encoder_dims)
+        for i in range(num_encoders):
+            encoder_layer = ZipformerEncoderLayer(
+                encoder_dims[i],
+                attention_dim[i],
+                nhead[i],
+                feedforward_dim[i],
+                dropout,
+                cnn_module_kernels[i],
+                pos_dim,
+            )
+
+            # For the first segment of the warmup period, we let the Conv2dSubsampling
+            # layer learn something. Then we start to warm up the other encoders.
+            encoder = ZipformerEncoder(
+                encoder_layer,
+                num_encoder_layers[i],
+                dropout,
+                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
+            )
+
+            if zipformer_downsampling_factors[i] != 1:
+                encoder = DownsampledZipformerEncoder(
+                    encoder,
+                    input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0],
+                    output_dim=encoder_dims[i],
+                    downsample=zipformer_downsampling_factors[i],
+                )
+            encoders.append(encoder)
+        self.encoders = nn.ModuleList(encoders)
+
+        # initializes self.skip_layers and self.skip_modules
+        self._init_skip_modules()
+
+        self.downsample_output = AttentionDownsample(encoder_dims[-1],
+                                                     encoder_dims[-1],
+                                                     downsample=output_downsampling_factor)
+
+
+    def _get_layer_skip_dropout_prob(self):
+        if not self.training:
+            return 0.0
+        batch_count = self.batch_count
+        min_dropout_prob = 0.025
+
+        if batch_count > self.warmup_end:
+            return min_dropout_prob
+        else:
+            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)
+
+    def _init_skip_modules(self):
+        """
+        If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
+        indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of
+        layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2,
+        we combine the outputs of layers 1 and 4.
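+        In code terms, skip_layers[i] holds the index of the earlier encoder stack whose
+        output gets combined in (or None), and skip_modules[i] is the SimpleCombiner
+        (or identity module) that performs that combination.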
+ """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i-1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i-2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + skip_layers.append(j) + skip_modules.append(SimpleCombiner(self.encoder_dims[j], + self.encoder_dims[i-1], + min_weight=(0.0,0.25))) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks( + self, + x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (num_frames, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dims) + if torch.jit.is_scripting() or not self.training: + return [ 1.0 ] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + + assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + + max_downsampling_factor = max(self.zipformer_downsampling_factors) + + num_frames_max = (num_frames0 + max_downsampling_factor - 1) + + + feature_mask_dropout_prob = 0.15 + + # frame_mask_max shape: (num_frames_max, batch_size, 1) + frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) + + feature_masks = [] + for i in range(num_encoders): + ds = self.zipformer_downsampling_factors[i] + upsample_factor = (max_downsampling_factor // ds) + + frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, + batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1)) + num_frames = (num_frames0 + ds - 1) // ds + frame_mask = frame_mask[:num_frames] + feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], + dtype=x.dtype, device=x.device) + u = self.encoder_unmasked_dims[i] + feature_mask[:, :, u:] *= frame_mask + feature_masks.append(feature_mask) + + return feature_masks + + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
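+            (Here output_seq_len is roughly seq_len // 4: the Conv2dSubsampling front-end
+            produces (seq_len - 7) // 2 frames, and the final downsample_output halves
+            that again, rounding up.)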
+ """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module(x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[...,::ds]) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, attention_dim, nhead, pos_dim, dropout=0.0, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, + feedforward_dim, + dropout) + + self.feed_forward2 = FeedforwardModule(d_model, + feedforward_dim, + dropout) + + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) + + + self.conv_module1 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
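+        # (ActivationBalancer does not change its input in the forward pass; it only
+        # modifies the backward-pass gradients, nudging each channel's proportion of
+        # positive values into [min_positive, max_positive] and its magnitude below max_abs.)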
+ self.balancer = ActivationBalancer( + d_model, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + max_abs=6.0, + ) + self.whiten = Whiten(num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01) + + def get_bypass_scale(self): + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + if random.random() < 0.1: + # ensure we get grads if self.bypass_scale becomes out of range + return self.bypass_scale + # hardcode warmup period for bypass scale + warmup_period = 20000.0 + initial_clamp_min = 0.75 + final_clamp_min = 0.25 + if self.batch_count > warmup_period: + clamp_min = final_clamp_min + else: + clamp_min = (initial_clamp_min - + (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + return self.bypass_scale.clamp(min=clamp_min, max=1.0) + + def get_dynamic_dropout_rate(self): + # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this + # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable + # at the beginning, by making the network focus on the feedforward modules. + if torch.jit.is_scripting() or not self.training: + return 0.0 + warmup_period = 2000.0 + initial_dropout_rate = 0.2 + final_dropout_rate = 0.0 + if self.batch_count > warmup_period: + return final_dropout_rate + else: + return (initial_dropout_rate - + (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + batch_split: if not None, this layer will only be applied to + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + # dropout rate for submodules that interact with time. 
+ dynamic_dropout = self.get_dynamic_dropout_rate() + + # pooling module + if torch.jit.is_scripting(): + src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) + elif random.random() > dynamic_dropout: + src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) + + if torch.jit.is_scripting(): + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + else: + use_self_attn = random.random() > dynamic_dropout + if use_self_attn: + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + if random.random() > dynamic_dropout: + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + if use_self_attn: + src = src + self.self_attn.forward2(src, attn_weights) + + if random.random() > dynamic_dropout: + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.get_bypass_scale() + + return self.whiten(src) + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). + self.module_seed = torch.randint(0, 1000, ()).item() + + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) + + + delta = (1. 
/ num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin + for i in range(num_layers): + self.layers[i].warmup_begin = cur_begin + cur_begin += delta + self.layers[i].warmup_end = cur_begin + + + def get_layers_to_drop(self, rnd_seed: int): + ans = set() + if not self.training: + return ans + + batch_count = self.batch_count + num_layers = len(self.layers) + + def get_layerdrop_prob(layer: int) -> float: + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + + initial_layerdrop_prob = 0.5 + final_layerdrop_prob = 0.05 + + if batch_count == 0: + # As a special case, if batch_count == 0, return 0 (drop no + # layers). This is rather ugly, I'm afraid; it is intended to + # enable our scan_pessimistic_batches_for_oom() code to work correctly + # so if we are going to get OOM it will happen early. + # also search for 'batch_count' with quotes in this file to see + # how we initialize the warmup count to a random number between + # 0 and 10. + return 0.0 + elif batch_count < layer_warmup_begin: + return initial_layerdrop_prob + elif batch_count > layer_warmup_end: + return final_layerdrop_prob + else: + # linearly interpolate + t = (batch_count - layer_warmup_begin) / layer_warmup_end + assert 0.0 <= t < 1.001, t + return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + + shared_rng = random.Random(batch_count + self.module_seed) + independent_rng = random.Random(rnd_seed) + + layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + tot = sum(layerdrop_probs) + # Instead of drawing the samples independently, we first randomly decide + # how many layers to drop out, using the same random number generator between + # jobs so that all jobs drop out the same number (this is for speed). + # Then we use an approximate approach to drop out the individual layers + # with their specified probs while reaching this exact target. + num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) + + layers = list(range(num_layers)) + independent_rng.shuffle(layers) + + # go through the shuffled layers until we get the required number of samples. + if num_to_drop > 0: + for layer in itertools.cycle(layers): + if independent_rng.random() < layerdrop_probs[layer]: + ans.add(layer) + if len(ans) == num_to_drop: + break + if shared_rng.random() < 0.005 or __name__ == "__main__": + logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + return ans + + + def forward( + self, + src: Tensor, + # Note: The type of feature_mask should be Union[float, Tensor], + # but to make torch.jit.script() work, we use `float` here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). 
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + pos_emb = self.encoder_pos(src) + output = src + + + if torch.jit.is_scripting(): + layers_to_drop = [] + else: + rnd_seed = src.numel() + random.randint(0, 1000) + layers_to_drop = self.get_layers_to_drop(rnd_seed) + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + if not torch.jit.is_scripting(): + if i in layers_to_drop: + continue + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + output = output * feature_mask + + return output + + +class DownsampledZipformerEncoder(nn.Module): + r""" + DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + def __init__(self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample(input_dim, output_dim, downsample) + self.encoder = encoder + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner(input_dim, + output_dim, + min_weight=(0.0, 0.25)) + + + def forward(self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + mask: the mask for the src sequence (optional). CAUTION: we need to downsample + this, if we are to support it. Won't work correctly yet. + src_key_padding_mask: the mask for the src keys per batch (optional). Should + be downsampled already. + + Shape: + src: (S, N, E). + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if mask is not None: + mask = mask[::ds,::ds] + + src = self.encoder( + src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + def __init__(self, + in_channels: int, + out_channels: int, + downsample: int): + """ + Require out_channels > in_channels. 
+ """ + super(AttentionDownsample, self).__init__() + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + + # fill in the extra dimensions with a projection of the input + if out_channels > in_channels: + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) + else: + self.extra_proj = None + self.downsample = downsample + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + def __init__(self, + num_channels: int, + upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. 
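+        At initialization weight1 is zero, so the combiner starts out simply returning src2.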
+    """
+    def __init__(self,
+                 dim1: int,
+                 dim2: int,
+                 min_weight: Tuple[float] = (0., 0.)):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 >= dim1, (dim2, dim1)
+        self.weight1 = nn.Parameter(torch.zeros(()))
+        self.min_weight = min_weight
+
+    def forward(self,
+                src1: Tensor,
+                src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape)
+
+        weight1 = self.weight1
+        if not torch.jit.is_scripting():
+            if self.training and random.random() < 0.25 and self.min_weight != (0., 0.):
+                weight1 = weight1.clamp(min=self.min_weight[0],
+                                        max=1.0-self.min_weight[1])
+
+        src1 = src1 * weight1
+        src2 = src2 * (1.0 - weight1)
+
+        src1_dim = src1.shape[-1]
+        src2_dim = src2.shape[-1]
+        if src1_dim != src2_dim:
+            if src1_dim < src2_dim:
+                src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim))
+            else:
+                src1 = src1[:src2_dim]
+
+        return src1 + src2
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct a PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(0) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means the position of the query vector and `j` means the
+        # position of the key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(0), self.d_model)
+        pe_negative = torch.zeros(x.size(0), self.d_model)
+        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (time, batch, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(0)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(0),
+        ]
+        return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+    we have to write up the differences.
+
+    Args:
+        embed_dim: total dimension of the model.
+        attention_dim: dimension in the attention module, may be less or more than embed_dim
+            but must be a multiple of num_heads.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
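+        pos_dim: dimension per head of the separate positional-encoding query
+            (the input projection reserves an extra pos_dim * num_heads channels for it).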
+ + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert ( + self.head_dim * num_heads == attention_dim + ), (self.head_dim, num_heads, attention_dim) + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = (2 * attention_dim + # query, key + attention_dim // 2 + # value + pos_dim * num_heads) # positional encoding query + + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=self.head_dim**-0.25) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, + initial_scale=0.05) + + # the following are for diagnosics only, see --print-diagnostics option. + # they only copy their inputs. + self.copy_pos_query = Identity() + self.copy_query = Identity() + + self.out_proj = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + + self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) + self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, + initial_scale=0.05) + # self.whiten_values2 is applied on the values in forward2() + self.whiten_values2 = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Returns: (attn_output, attn_weights) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + + # self-attention + q = x_proj[...,0:attention_dim] + k = x_proj[...,attention_dim:2*attention_dim] + value_dim = attention_dim // 2 + v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[...,2*attention_dim+value_dim:] + + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(seq_len, bsz, num_heads, head_dim) + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len + ) + + + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + + seq_len2 = 2 * seq_len - 1 + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), + (pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2)-pos_weights.stride(3), + pos_weights.stride(3)), + storage_offset=pos_weights.stride(3) * (seq_len - 1)) + + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + if not torch.jit.is_scripting(): + if training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_output_weights = penalize_abs_values_gt(attn_output_weights, + limit=25.0, + penalty=1.0e-04) + + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + seq_len, + seq_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. 
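+        # (The forward output matches a regular softmax; only what gets saved for backward differs.)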
+ attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, + head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + return attn_output, attn_output_weights + + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. 
+ return self.out_proj2(attn_output) + + + def _print_attn_stats( + self, + attn_weights: Tensor, + attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + # attn_covar: (num_heads, head_dim, head_dim) + #eigs, _ = torch.symeig(attn_covar) + #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) + out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") + + + + +class PoolingModule(nn.Module): + """ + Averages the input over the time dimension and project with a square matrix. + """ + def __init__(self, + d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, + initial_scale=0.1, bias=False) + + def forward(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None): + """ + Args: + x: a Tensor of shape (T, N, C) + key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked + positions. + Returns: + a Tensor of shape (1, N, C) + """ + if key_padding_mask is not None: + pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) + pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) + # now pooling_mask: (T, N, 1) + x = (x * pooling_mask).sum(dim=0, keepdim=True) + else: + num_frames = x.shape[0] + pooling_mask = 1.0 / num_frames + x = (x * pooling_mask).sum(dim=0, keepdim=True) + + x = self.proj(x) + return x + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer model. + """ + def __init__(self, + d_model: int, + feedforward_dim: int, + dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer(feedforward_dim, + channel_dim=-1, max_abs=10.0, + min_prob=0.25) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, + initial_scale=0.01) + + def forward(self, + x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer model. 
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channels, channel_dim=1, + min_positive=0.05, max_positive=1.0, + max_abs=20.0, + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ActivationBalancer(layer1_channels, + channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + ActivationBalancer(layer2_channels, + channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + ActivationBalancer(layer3_channels, + channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, (T-7)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + # Now x is of shape (N, (T-7)//2, odim) + x = self.dropout(x) + return x + +class AttentionCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + All but the last input will have a linear transform before we + randomly combine them; these linear transforms will be initialized + to the identity transform. + + The idea is that the list of Tensors will be a list of outputs of multiple + zipformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_channels: int, + num_inputs: int, + random_prob: float = 0.25, + single_prob: float = 0.333, + ) -> None: + """ + Args: + num_channels: + the number of channels + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + random_prob: + the probability with which we apply a nontrivial mask, in training + mode. 
+ single_prob: + the probability with which we mask to allow just a single + module's output (in training) + """ + super().__init__() + + self.random_prob = random_prob + self.single_prob = single_prob + self.weight = torch.nn.Parameter(torch.zeros(num_channels, + num_inputs)) + self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) + + assert 0 <= random_prob <= 1, random_prob + assert 0 <= single_prob <= 1, single_prob + + + + def forward(self, inputs: List[Tensor]) -> Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.weight.shape[1] + assert len(inputs) == num_inputs + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias + + if random.random() < 0.002: + logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}") + + if self.training: + # random masking.. + mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), + size=(num_frames,), device=scores.device).unsqueeze(1) + # mask will have rows like: [ False, False, False, True, True, .. ] + arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( + num_frames, num_inputs) + mask = arange >= mask_start + + apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), + device=scores.device) < self.single_prob, + mask_start < num_inputs) + single_prob_mask = torch.logical_and(apply_single_prob, + arange < mask_start - 1) + + mask = torch.logical_or(mask, + single_prob_mask) + + scores = scores.masked_fill(mask, float('-inf')) + + if self.training and random.random() < 0.1: + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1), + ans = torch.matmul(stacked_inputs, weights.unsqueeze(2)) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... + print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + + +def _test_random_combine(): + print("_test_random_combine()") + num_inputs = 3 + num_channels = 50 + m = AttentionCombine( + num_channels=num_channels, + num_inputs=num_inputs, + random_prob=0.5, + single_prob=0.0) + + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_zipformer_main(): + feature_dim = 50 + batch_size = 5 + seq_len = 20 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Zipformer( + num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. 
+ f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + +def _test_conv2d_subsampling(): + num_features = 80 + encoder_dims = 384 + dropout = 0.1 + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, + dropout=dropout) + for i in range(20, 40): + x = torch.rand(2, i, num_features) + y = encoder_embed(x) + assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_random_combine() + _test_zipformer_main() + _test_conv2d_subsampling() diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 170586455..5069b78e8 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -86,7 +86,7 @@ def save_checkpoint( } if model_avg is not None: - checkpoint["model_avg"] = model_avg.state_dict() + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() if params: for k, v in params.items(): @@ -466,8 +466,10 @@ def average_state_dict( uniqued_names = list(uniqued.values()) for k in uniqued_names: - state_dict_1[k] *= weight_1 - state_dict_1[k] += ( - state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 - ) - state_dict_1[k] *= scaling_factor + v = state_dict_1[k] + if torch.is_floating_point(v): + v *= weight_1 + v += ( + state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + ) + v *= scaling_factor diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 609e25626..b075aceac 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -19,7 +19,7 @@ import random from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, List import torch from torch import Tensor, nn @@ -82,11 +82,18 @@ def get_tensor_stats( elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: - assert stats_type == "value" + assert stats_type in [ "value", "max", "min" ] sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: - x = torch.sum(x, dim=sum_dims) + if stats_type == "max": + for dim in reversed(sum_dims): + x = torch.max(x, dim=dim)[0] + elif stats_type == "min": + for dim in reversed(sum_dims): + x = torch.min(x, dim=dim)[0] + else: + x = torch.sum(x, dim=sum_dims) x = x.flatten() return x, count @@ -105,17 +112,19 @@ class TensorDiagnostic(object): opts: Options object. name: - The tensor name. + The name associated with this diagnostics object, will probably be {module_name}.X + where X is "output" or "grad", or {parameter_name}.Y where Y is param_value or param_grad. """ def __init__(self, opts: TensorDiagnosticOptions, name: str): - self.name = name self.opts = opts + self.name = name + self.class_name = None # will assign in accumulate() self.stats = None # we'll later assign a list to this data member. It's a list of dict. # the keys into self.stats[dim] are strings, whose values can be - # "abs", "value", "positive", "rms", "value". + # "abs", "max", "min" ,"value", "positive", "rms", "value". # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount, # containing a tensor and its associated count (which is the sum of the other dims # that we aggregated over, e.g. 
the number of frames and/or batch elements and/or @@ -124,8 +133,13 @@ class TensorDiagnostic(object): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. - def accumulate(self, x): - """Accumulate tensors.""" + + def accumulate(self, x, class_name: Optional[str] = None): + """ + Accumulate tensors. + """ + if class_name is not None: + self.class_name = class_name if isinstance(x, Tuple): x = x[0] if not isinstance(x, Tensor): @@ -142,11 +156,11 @@ class TensorDiagnostic(object): for dim in range(ndim): this_dim_stats = self.stats[dim] if ndim > 1: - stats_types = ["abs", "positive", "value", "rms"] + stats_types = ["abs", "max", "min", "positive", "value", "rms"] if x.shape[dim] <= self.opts.max_eig_dim: stats_types.append("eigs") else: - stats_types = ["value", "abs"] + stats_types = ["value", "abs", "max", "min"] for stats_type in stats_types: stats, count = get_tensor_stats(x, dim, stats_type) @@ -161,7 +175,12 @@ class TensorDiagnostic(object): continue for s in this_dim_stats[stats_type]: if s.tensor.shape == stats.shape: - s.tensor += stats + if stats_type == "max": + s.tensor = torch.maximum(s.tensor, stats) + elif stats_type == "min": + s.tensor = torch.minimum(s.tensor, stats) + else: + s.tensor += stats s.count += count done = True break @@ -186,14 +205,26 @@ class TensorDiagnostic(object): for dim, this_dim_stats in enumerate(self.stats): for stats_type, stats_list in this_dim_stats.items(): # stats_type could be "rms", "value", "abs", "eigs", "positive". - # "value" could be a list of TensorAndCount, or None + # "stats_list" could be a list of TensorAndCount (one list per distinct tensor + # shape of the stats), or None if stats_list is None: assert stats_type == "eigs" continue + + def get_count(count): + return 1 if stats_type in ["max", "min"] else count + + if len(stats_list) == 1: + stats = stats_list[0].tensor / get_count(stats_list[0].count) + else: + # a dimension that has variable size in different nnet + # forwards, e.g. a time dimension in an ASR model. + stats = torch.cat( + [x.tensor / get_count(x.count) for x in stats_list], dim=0 + ) + if stats_type == "eigs": - assert len(stats_list) == 1 - stats = stats_list[0].tensor / stats_list[0].count try: eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() @@ -201,15 +232,9 @@ class TensorDiagnostic(object): print( "Error getting eigenvalues, trying another method." ) - eigs = torch.linalg.eigvals(stats) + eigs, _ = torch.eig(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - elif len(stats_list) == 1: - stats = stats_list[0].tensor / stats_list[0].count - else: - stats = torch.cat( - [x.tensor / x.count for x in stats_list], dim=0 - ) if stats_type == "rms": # we stored the square; after aggregation we need to take sqrt. 
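The two aggregation rules used above are worth spelling out: additive stats types ("value", "abs", "rms", "positive", "eigs") accumulate a running sum plus a count and are later normalized, while the new "max"/"min" types keep a running elementwise extreme and are printed with an effective count of 1. A minimal standalone sketch of that distinction (the shapes, names and batch loop here are made up for illustration and are not the TensorDiagnostic API):

```python
import torch

# Additive stats: running sum + count, report sum / count.
# Extreme stats ("max"): running elementwise maximum, report it unchanged.
def aggregate(batches, stats_type: str):
    acc, count = None, 0
    for x in batches:  # each x: (num_frames, num_channels)
        stat = x.max(dim=0)[0] if stats_type == "max" else x.abs().sum(dim=0)
        count += x.shape[0]
        if acc is None:
            acc = stat
        elif stats_type == "max":
            acc = torch.maximum(acc, stat)   # mirrors: s.tensor = torch.maximum(s.tensor, stats)
        else:
            acc = acc + stat                 # mirrors: s.tensor += stats
    return acc if stats_type == "max" else acc / count

batches = [torch.randn(100, 8) for _ in range(5)]
print(aggregate(batches, "abs"))  # per-channel mean absolute value
print(aggregate(batches, "max"))  # per-channel maximum over all frames
```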
@@ -236,7 +261,7 @@ class TensorDiagnostic(object): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type == "value": + if stats_type in [ "value", "rms", "eigs" ]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue @@ -245,7 +270,7 @@ class TensorDiagnostic(object): ans += f", norm={norm:.2g}" mean = stats.mean().item() rms = (stats ** 2).mean().sqrt().item() - ans += f", mean={mean:.2g}, rms={rms:.2g}" + ans += f", mean={mean:.3g}, rms={rms:.3g}" # OK, "ans" contains the actual stats, e.g. # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5" @@ -256,11 +281,13 @@ class TensorDiagnostic(object): if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" ) + maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" print( - f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}" + f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) + class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -321,20 +348,29 @@ def attach_diagnostics( def forward_hook( _module, _input, _output, _model_diagnostic=ans, _name=name ): + if isinstance(_output, tuple) and len(_output) == 1: + _output = _output[0] + if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.output"].accumulate(_output) + _model_diagnostic[f"{_name}.output"].accumulate(_output, + class_name=type(_module).__name__) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, + class_name=type(_module).__name__) def backward_hook( _module, _input, _output, _model_diagnostic=ans, _name=name ): + if isinstance(_output, tuple) and len(_output) == 1: + _output = _output[0] if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.grad"].accumulate(_output) + _model_diagnostic[f"{_name}.grad"].accumulate(_output, + class_name=type(_module).__name__) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o) + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, + class_name=type(_module).__name__) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) diff --git a/icefall/hooks.py b/icefall/hooks.py new file mode 100644 index 000000000..fbcf5e148 --- /dev/null +++ b/icefall/hooks.py @@ -0,0 +1,102 @@ +# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
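For context, these per-module statistics are normally collected by attaching the hooks above to a model for a handful of batches and then printing a summary. A rough usage sketch, based on how the icefall recipes typically drive it (the option value `2 ** 22` and the exact option name are assumptions from memory and may differ from the current API):

```python
import torch
from icefall import diagnostics

# Any nn.Module works; a tiny stand-in model keeps the sketch self-contained.
model = torch.nn.Sequential(
    torch.nn.Linear(80, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 500),
)

# Register the forward/backward hooks defined in attach_diagnostics() above.
opts = diagnostics.TensorDiagnosticOptions(2 ** 22)  # rough memory budget for stored stats (assumed)
diagnostic = diagnostics.attach_diagnostics(model, opts)

for _ in range(30):  # a few forward/backward passes are enough to populate the stats
    x = torch.randn(16, 80)
    model(x).sum().backward()
    model.zero_grad()

# One line per module/dim/stats type, now including the "max"/"min" rows and the
# "type=<ClassName>" field introduced in this patch.
diagnostic.print_diagnostics()
```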
+ +import random +import torch +from torch import Tensor, nn +import logging + + +def register_inf_check_hooks(model: nn.Module) -> None: + """Registering forward hook on each module, to check + whether its output tensors is not finite. + + Args: + model: + the model to be analyzed. + """ + + for name, module in model.named_modules(): + if name == "": + name = "" + + # default param _name is a way to capture the current value of the variable "name". + def forward_hook(_module, _input, _output, _name=name): + if isinstance(_output, Tensor): + if not torch.isfinite(_output.to(torch.float32).sum()): + raise ValueError( + f"The sum of {_name}.output is not finite: {_output}" + ) + elif isinstance(_output, tuple): + for i, o in enumerate(_output): + if isinstance(o, tuple): + o = o[0] + if not isinstance(o, Tensor): + continue + if not torch.isfinite(o.to(torch.float32).sum()): + raise ValueError( + f"The sum of {_name}.output[{i}] is not finite: {_output}" + ) + + # default param _name is a way to capture the current value of the variable "name". + def backward_hook(_module, _input, _output, _name=name): + if isinstance(_output, Tensor): + if not torch.isfinite(_output.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.grad is not finite" # ": {_output}" + ) + elif isinstance(_output, tuple): + for i, o in enumerate(_output): + if isinstance(o, tuple): + o = o[0] + if not isinstance(o, Tensor): + continue + if not torch.isfinite(o.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.grad[{i}] is not finite" + ) + + module.register_forward_hook(forward_hook) + module.register_backward_hook(backward_hook) + + + for name, parameter in model.named_parameters(): + + def param_backward_hook( + grad, _name=name + ): + if not torch.isfinite(grad.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.param_grad is not finite" + ) + + parameter.register_hook(param_backward_hook) + + + +def _test_inf_check_hooks(): + model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) + + register_inf_check_hooks(model) + for _ in range(10): + T = random.randint(200, 300) + x = torch.randn(T, 100) + float("inf") * (T % 2) + y = model(x) + y.sum().backward() + + +if __name__ == "__main__": + _test_inf_check_hooks() From cedf9aa24f01288f24079fb919743675a25d0832 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 13 Nov 2022 11:51:00 +0800 Subject: [PATCH 027/120] Fix shallow fusion and add CI tests for it (#676) * Fix shallow fusion and add CI tests for it * Fix -1 index in embedding introduced in the zipformer PR --- ...-lstm-transducer-stateless2-2022-09-03.yml | 30 +++++++++++++++++++ ...-lstm-transducer-stateless2-2022-09-03.yml | 15 ++++++++-- .../beam_search.py | 2 +- .../pruned_transducer_stateless2/decoder.py | 10 ++++++- 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index b89055c72..6ce92d022 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -174,6 +174,36 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + +if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then + lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + log "Download pre-trained RNN-LM model from ${lm_repo_url}" + git 
clone $lm_repo_url + lm_repo=$(basename $lm_repo_url) + pushd $lm_repo + git lfs pull --include "exp/pretrained.pt" + cd exp + ln -s pretrained.pt epoch-88.pt + popd + + ./lstm_transducer_stateless2/decode.py \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --lang-dir $repo/data/lang_bpe_500 \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --beam 4 \ + --rnn-lm-scale 0.3 \ + --rnn-lm-exp-dir $lm_repo/exp \ + --rnn-lm-epoch 88 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 +fi + if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then mkdir -p lstm_transducer_stateless2/exp ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index dd67771ba..a90841fb6 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -18,7 +18,7 @@ on: jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: - if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -128,9 +128,20 @@ jobs: find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + - name: Display decoding results for lstm_transducer_stateless2 + if: github.event.label.name == 'shallow-fusion' + shell: bash + run: | + cd egs/librispeech/ASR + tree lstm_transducer_stateless2/exp + cd lstm_transducer_stateless2/exp + echo "===modified_beam_search_rnnlm_shallow_fusion===" + find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' + if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03 path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index da88e257b..b7c2010f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2083,7 +2083,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_prob=hyp_log_prob, state=state, lm_score=lm_score, - timestampe=new_timestamp, + timestamp=new_timestamp, ) B[i].add(new_hyp) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 
e01167285..ba91302ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -101,7 +101,15 @@ class Decoder(nn.Module): need_pad = bool(need_pad) y = y.to(torch.int64) - embedding_out = self.embedding(y) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + if torch.jit.is_tracing(): + # This is for exporting to PNNX via ONNX + embedding_out = self.embedding(y) + else: + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( + -1 + ) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: From 62302259d093517341fc833bab6e6dcd9cb9bc9e Mon Sep 17 00:00:00 2001 From: ahmedalbahnasawy <50653875+ahmedalbahnasawy@users.noreply.github.com> Date: Mon, 14 Nov 2022 20:11:42 +0400 Subject: [PATCH 028/120] add kaldifeat (#680) --- docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index 524303fb8..3637d2f11 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -68,6 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH -WORKDIR /workspace/icefall \ No newline at end of file +WORKDIR /workspace/icefall From 952a7b3fcc7dc1ad0f87f431edecdcb4f2c6fd3b Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Tue, 15 Nov 2022 10:45:48 +0800 Subject: [PATCH 029/120] Fix typo (#681) * Update add_alignment_librispeech.py * Update scaling_converter.py --- egs/librispeech/ASR/local/add_alignment_librispeech.py | 2 +- .../ASR/pruned_transducer_stateless3/scaling_converter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cd1bcea67..fe6a26c51 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -171,7 +171,7 @@ def add_alignment( ali = alignments[origin_id] else: logging.info( - f"Warning: {origin_id} does not has alignment." + f"Warning: {origin_id} does not have alignment." ) ali = [] subcut.alignment = {"word": ali} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e7e808c7..1e6022b57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -87,7 +87,7 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: in_features=scaled_linear.in_features, out_features=scaled_linear.out_features, bias=True, # otherwise, it throws errors when converting to PNNX format - # device=weight.device, # Pytorch version before v1.9.0 does not has + # device=weight.device, # Pytorch version before v1.9.0 does not have # this argument. 
Comment out for now, we will # see if it will raise error for versions # after v1.9.0 From 855c76655b49de61a5c6d054a7ff2158a639e6f7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Nov 2022 16:56:05 +0800 Subject: [PATCH 030/120] Add zipformer from Dan using multi-dataset setup (#675) * Bug fix * Change subsamplling factor from 1 to 2 * Implement AttentionCombine as replacement for RandomCombine * Decrease random_prob from 0.5 to 0.333 * Add print statement * Apply single_prob mask, so sometimes we just get one layer as output. * Introduce feature mask per frame * Include changes from Liyong about padding conformer module. * Reduce single_prob from 0.5 to 0.25 * Reduce feature_mask_dropout_prob from 0.25 to 0.15. * Remove dropout from inside ConformerEncoderLayer, for adding to residuals * Increase feature_mask_dropout_prob from 0.15 to 0.2. * Swap random_prob and single_prob, to reduce prob of being randomized. * Decrease feature_mask_dropout_prob back from 0.2 to 0.15, i.e. revert the 43->48 change. * Randomize order of some modules * Bug fix * Stop backprop bug * Introduce a scale dependent on the masking value * Implement efficient layer dropout * Simplify the learned scaling factor on the modules * Compute valid loss on batch 0. * Make the scaling factors more global and the randomness of dropout more random * Bug fix * Introduce offset in layerdrop_scaleS * Remove final combination; implement layer drop that drops the final layers. * Bug fices * Fix bug RE self.training * Fix bug setting layerdrop mask * Fix eigs call * Add debug info * Remove warmup * Remove layer dropout and model-level warmup * Don't always apply the frame mask * Slight code cleanup/simplification * Various fixes, finish implementating frame masking * Remove debug info * Don't compute validation if printing diagnostics. * Apply layer bypass during warmup in a new way, including 2s and 4s of layers. * Update checkpoint.py to deal with int params * Revert initial_scale to previous values. * Remove the feature where it was bypassing groups of layers. * Implement layer dropout with probability 0.075 * Fix issue with warmup in test time * Add warmup schedule where dropout disappears from earlier layers first. * Have warmup that gradually removes dropout from layers; multiply initialization scales by 0.1. * Do dropout a different way * Fix bug in warmup * Remove debug print * Make the warmup mask per frame. * Implement layer dropout (in a relatively efficient way) * Decrease initial keep_prob to 0.25. * Make it start warming up from the very start, and increase warmup_batches to 6k * Change warmup schedule and increase warmup_batches from 4k to 6k * Make the bypass scale trainable. * Change the initial keep-prob back from 0.25 to 0.5 * Bug fix * Limit bypass scale to >= 0.1 * Revert "Change warmup schedule and increase warmup_batches from 4k to 6k" This reverts commit 86845bd5d859ceb6f83cd83f3719c3e6641de987. * Do warmup by dropping out whole layers. * Decrease frequency of logging variance_proportion * Make layerdrop different in different processes. * For speed, drop the same num layers per job. * Decrease initial_layerdrop_prob from 0.75 to 0.5 * Revert also the changes in scaled_adam_exp85 regarding warmup schedule * Remove unused code LearnedScale. * Reintroduce batching to the optimizer * Various fixes from debugging with nvtx, but removed the NVTX annotations. * Only apply ActivationBalancer with prob 0.25. * Fix s -> scaling for import. 
* Increase final layerdrop prob from 0.05 to 0.075 * Fix bug where fewer layers were dropped than should be; remove unnecesary print statement. * Fix bug in choosing layers to drop * Refactor RelPosMultiheadAttention to have 2nd forward function and introduce more modules in conformer encoder layer * Reduce final layerdrop_prob from 0.075 to 0.05. * Fix issue with diagnostics if stats is None * Remove persistent attention scores. * Make ActivationBalancer and MaxEig more efficient. * Cosmetic improvements * Change scale_factor_scale from 0.5 to 0.8 * Make the ActivationBalancer regress to the data mean, not zero, when enforcing abs constraint. * Remove unused config value * Fix bug when channel_dim < 0 * Fix bug when channel_dim < 0 * Simplify how the positional-embedding scores work in attention (thanks to Zengwei for this concept) * Revert dropout on attention scores to 0.0. * This should just be a cosmetic change, regularizing how we get the warmup times from the layers. * Reduce beta from 0.75 to 0.0. * Reduce stats period from 10 to 4. * Reworking of ActivationBalancer code to hopefully balance speed and effectiveness. * Add debug code for attention weihts and eigs * Remove debug statement * Add different debug info. * Penalize attention-weight entropies above a limit. * Remove debug statements * use larger delta but only penalize if small grad norm * Bug fixes; change debug freq * Change cutoff for small_grad_norm * Implement whitening of values in conformer. * Also whiten the keys in conformer. * Fix an issue with scaling of grad. * Decrease whitening limit from 2.0 to 1.1. * Fix debug stats. * Reorganize Whiten() code; configs are not the same as before. Also remove MaxEig for self_attn module * Bug fix RE float16 * Revert whitening_limit from 1.1 to 2.2. * Replace MaxEig with Whiten with limit=5.0, and move it to end of ConformerEncoderLayer * Change LR schedule to start off higher * Simplify the dropout mask, no non-dropped-out sequences * Make attention dims configurable, not embed_dim//2, trying 256. * Reduce attention_dim to 192; cherry-pick scaled_adam_exp130 which is linear_pos interacting with query * Use half the dim for values, vs. keys and queries. * Increase initial-lr from 0.04 to 0.05, plus changes for diagnostics * Cosmetic changes * Changes to avoid bug in backward hooks, affecting diagnostics. * Random clip attention scores to -5..5. * Add some random clamping in model.py * Add reflect=0.1 to invocations of random_clamp() * Remove in_balancer. * Revert model.py so there are no constraints on the output. * Implement randomized backprop for softmax. * Reduce min_abs from 1e-03 to 1e-04 * Add RandomGrad with min_abs=1.0e-04 * Use full precision to do softmax and store ans. * Fix bug in backprop of random_clamp() * Get the randomized backprop for softmax in autocast mode working. * Remove debug print * Reduce min_abs from 1.0e-04 to 5.0e-06 * Add hard limit of attention weights to +- 50 * Use normal implementation of softmax. * Remove use of RandomGrad * Remove the use of random_clamp in conformer.py. * Reduce the limit on attention weights from 50 to 25. * Reduce min_prob of ActivationBalancer from 0.1 to 0.05. * Penalize too large weights in softmax of AttentionDownsample() * Also apply limit on logit in SimpleCombiner * Increase limit on logit for SimpleCombiner to 25.0 * Add more diagnostics to debug gradient scale problems * Changes to grad scale logging; increase grad scale more frequently if less than one. 
* Add logging * Remove comparison diagnostics, which were not that useful. * Configuration changes: scores limit 5->10, min_prob 0.05->0.1, cur_grad_scale more aggressive increase * Reset optimizer state when we change loss function definition. * Make warmup period decrease scale on simple loss, leaving pruned loss scale constant. * Cosmetic change * Increase initial-lr from 0.05 to 0.06. * Increase initial-lr from 0.06 to 0.075 and decrease lr-epochs from 3.5 to 3. * Fixes to logging statements. * Introduce warmup schedule in optimizer * Increase grad_scale to Whiten module * Add inf check hooks * Renaming in optim.py; remove step() from scan_pessimistic_batches_for_oom in train.py * Change base lr to 0.1, also rename from initial lr in train.py * Adding activation balancers after simple_am_prob and simple_lm_prob * Reduce max_abs on am_balancer * Increase max_factor in final lm_balancer and am_balancer * Use penalize_abs_values_gt, not ActivationBalancer. * Trying to reduce grad_scale of Whiten() from 0.02 to 0.01. * Add hooks.py, had negleted to git add it. * don't do penalize_values_gt on simple_lm_proj and simple_am_proj; reduce --base-lr from 0.1 to 0.075 * Increase probs of activation balancer and make it decay slower. * Dont print out full non-finite tensor * Increase default max_factor for ActivationBalancer from 0.02 to 0.04; decrease max_abs in ConvolutionModule.deriv_balancer2 from 100.0 to 20.0 * reduce initial scale in GradScaler * Increase max_abs in ActivationBalancer of conv module from 20 to 50 * --base-lr0.075->0.5; --lr-epochs 3->3.5 * Revert 179->180 change, i.e. change max_abs for deriv_balancer2 back from 50.0 20.0 * Save some memory in the autograd of DoubleSwish. * Change the discretization of the sigmoid to be expectation preserving. * Fix randn to rand * Try a more exact way to round to uint8 that should prevent ever wrapping around to zero * Make it use float16 if in amp but use clamp to avoid wrapping error * Store only half precision output for softmax. * More memory efficient backprop for DoubleSwish. * Change to warmup schedule. * Changes to more accurately estimate OOM conditions * Reduce cutoff from 100 to 5 for estimating OOM with warmup * Make 20 the limit for warmup_count * Cast to float16 in DoubleSwish forward * Hopefully make penalize_abs_values_gt more memory efficient. * Add logging about memory used. * Change scalar_max in optim.py from 2.0 to 5.0 * Regularize how we apply the min and max to the eps of BasicNorm * Fix clamping of bypass scale; remove a couple unused variables. * Increase floor on bypass_scale from 0.1 to 0.2. * Increase bypass_scale from 0.2 to 0.4. * Increase bypass_scale min from 0.4 to 0.5 * Rename conformer.py to zipformer.py * Rename Conformer to Zipformer * Update decode.py by copying from pruned_transducer_stateless5 and changing directory name * Remove some unused variables. * Fix clamping of epsilon * Refactor zipformer for more flexibility so we can change number of encoder layers. * Have a 3rd encoder, at downsampling factor of 8. * Refactor how the downsampling is done so that it happens later, but the 1st encoder stack still operates after a subsampling of 2. * Fix bug RE seq lengths * Have 4 encoder stacks * Have 6 different encoder stacks, U-shaped network. * Reduce dim of linear positional encoding in attention layers. * Reduce min of bypass_scale from 0.5 to 0.3, and make it not applied in test mode. * Tuning change to num encoder layers, inspired by relative param importance. * Make decoder group size equal to 4. 
* Add skip connections as in normal U-net * Avoid falling off the loop for weird inputs * Apply layer-skip dropout prob * Have warmup schedule for layer-skipping * Rework how warmup count is produced; should not affect results. * Add warmup schedule for zipformer encoder layer, from 1.0 -> 0.2. * Reduce initial clamp_min for bypass_scale from 1.0 to 0.5. * Restore the changes from scaled_adam_219 and scaled_adam_exp220, accidentally lost, re layer skipping * Change to schedule of bypass_scale min: make it larger, decrease slower. * Change schedule after initial loss not promising * Implement pooling module, add it after initial feedforward. * Bug fix * Introduce dropout rate to dynamic submodules of conformer. * Introduce minimum probs in the SimpleCombiner * Add bias in weight module * Remove dynamic weights in SimpleCombine * Remove the 5th of 6 encoder stacks * Fix some typos * small fixes * small fixes * Copy files * Update decode.py * Add changes from the master * Add changes from the master * update results * Add CI * Small fixes * Small fixes Co-authored-by: Daniel Povey --- ...pruned-transducer-stateless7-2022-11-11.sh | 1 + ...pruned-transducer-stateless8-2022-11-14.sh | 116 ++ .../run-librispeech-2022-11-14-stateless8.yml | 155 ++ egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 58 + .../jit_pretrained.py | 1 + .../pruned_transducer_stateless8/__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../pruned_transducer_stateless8/decode.py | 863 +++++++++++ .../pruned_transducer_stateless8/decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless8/export.py | 334 ++++ .../gigaspeech.py | 1 + .../jit_pretrained.py | 275 ++++ .../pruned_transducer_stateless8/joiner.py | 1 + .../librispeech.py | 1 + .../ASR/pruned_transducer_stateless8/model.py | 222 +++ .../ASR/pruned_transducer_stateless8/optim.py | 1 + .../pretrained.py | 363 +++++ .../pruned_transducer_stateless8/scaling.py | 1 + .../scaling_converter.py | 1 + .../ASR/pruned_transducer_stateless8/train.py | 1367 +++++++++++++++++ .../pruned_transducer_stateless8/zipformer.py | 1 + 24 files changed, 3767 insertions(+) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh create mode 100644 .github/workflows/run-librispeech-2022-11-14-stateless8.yml create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/export.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless8/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py create mode 120000 
egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/train.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 75861bbc7..8e485d2e6 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -33,6 +33,7 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ + --use-averaged-model false \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh new file mode 100755 index 000000000..e782b8425 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Export to torchscript model" +./pruned_transducer_stateless8/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model false \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless8/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless8/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model 
$repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless8/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless8/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless8/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless8/exp + done + + rm pruned_transducer_stateless8/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml new file mode 100644 index 000000000..eaab35189 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml @@ -0,0 +1,155 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-librispeech-2022-11-14-stateless8 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_11_14_zipformer_stateless8: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless8 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless8/exp + + cd pruned_transducer_stateless8 + echo "results for pruned_transducer_stateless8" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} 
+ | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless8 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14 + path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/ diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index c366650bb..e737d68bd 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -23,6 +23,7 @@ The following table lists the differences among them. | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| +| `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | | `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 43cd67c85..030e47b86 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### pruned_transducer_stateless8 (zipformer + multidataset) + +See for more details. + +[pruned_transducer_stateless8](./pruned_transducer_stateless8) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. 
+ +Number of model parameters: 70369391, i.e., 70.37 M + +| | test-clean | test-other | comment | +|----------------------|------------|-------------|----------------------------------------| +| greedy search | 1.87 | 4.38 | --epoch 16 --avg 2 --max-duration 600 | +| modified beam search | 1.81 | 4.34 | --epoch 16 --avg 2 --max-duration 600 | +| fast beam search | 1.91 | 4.33 | --epoch 16 --avg 2 --max-duration 600 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless8/train.py \ + --world-size 8 \ + --num-epochs 20 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless8/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --master-port 12535 \ + --giga-prob 0.9 +``` + +The decoding commands are: +```bash +for m in greedy_search fast_beam_search modified_beam_search ; do + for epoch in 16; do + for avg in 2; do + ./pruned_transducer_stateless8/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + + ### pruned_transducer_stateless7 (zipformer) See for more details. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index 81b0deba3..e2405d5ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -30,6 +30,7 @@ Usage of this script: ./pruned_transducer_stateless7/jit_pretrained.py \ --nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ /path/to/foo.wav \ /path/to/bar.wav """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py new file mode 120000 index 000000000..3ba9ada4f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py new file mode 100755 index 000000000..9d7335e77 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -0,0 +1,863 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from librispeech import LibriSpeech +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py new file mode 100755 index 000000000..49f469e29 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. 
You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless8/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless8/decode.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
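+        # torch.jit.ignore leaves forward() as a plain Python method, so
+        # torch.jit.script() does not attempt to compile it; the exported
+        # module is used through its encoder/decoder/joiner submodules
+        # instead (see ./jit_pretrained.py, which calls them directly).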
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py new file mode 120000 index 000000000..5242c652a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py new file mode 100755 index 000000000..e79a3a3aa --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless8/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless8/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = 
kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py new file mode 120000 index 000000000..b76723bf5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/librispeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py new file mode 100644 index 000000000..497b89136 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -0,0 +1,222 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
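+#
+# This file defines the Transducer model for pruned_transducer_stateless8.
+# Besides the usual encoder/decoder/joiner triple, it can optionally hold a
+# second decoder/joiner pair (decoder_giga/joiner_giga) so that LibriSpeech
+# and GigaSpeech share a single encoder while keeping dataset-specific
+# prediction networks; see the `libri` argument of forward() below.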
+ + +import random +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + decoder_giga: Optional[nn.Module] = None, + joiner_giga: Optional[nn.Module] = None, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.decoder_giga = decoder_giga + self.joiner_giga = joiner_giga + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + if decoder_giga is not None: + self.simple_am_proj_giga = nn.Linear(encoder_dim, vocab_size) + self.simple_lm_proj_giga = nn.Linear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + libri: bool = True, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + libri: + True to use the decoder and joiner for the LibriSpeech dataset. + False to use the decoder and joiner for the GigaSpeech dataset. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + if libri: + decoder = self.decoder + simple_lm_proj = self.simple_lm_proj + simple_am_proj = self.simple_am_proj + joiner = self.joiner + else: + decoder = self.decoder_giga + simple_lm_proj = self.simple_lm_proj_giga + simple_am_proj = self.simple_am_proj_giga + joiner = self.joiner_giga + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = simple_lm_proj(decoder_out) + am = simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=joiner.encoder_proj(encoder_out), + lm=joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
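+        # With project_input=False the joiner skips its internal
+        # encoder_proj/decoder_proj (they were already applied to
+        # encoder_out/decoder_out above) and combines the pruned inputs
+        # directly before its output projection.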
+ logits = joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py new file mode 100755 index 000000000..373a48fc1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless8/exp/epoch-xx.pt`. 
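+For instance, to decode with such an epoch checkpoint (the epoch number below
+is only illustrative):
+
+./pruned_transducer_stateless8/pretrained.py \
+  --checkpoint ./pruned_transducer_stateless8/exp/epoch-30.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method greedy_search \
+  /path/to/foo.wav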
+ +Note: ./pruned_transducer_stateless8/exp/pretrained.pt is generated by +./pruned_transducer_stateless8/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params, enable_giga=False) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + 
+ s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py new file mode 100755 index 000000000..b4177d3f0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -0,0 +1,1367 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +cd egs/librispeech/ASR/ +./prepare.sh +./prepare_giga_speech.sh + +./pruned_transducer_stateless8/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless8/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless8/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless8/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from gigaspeech import GigaSpeech +from joiner import Joiner +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--giga-prob", + type=float, + default=0.5, + help="The probability to select a batch from the GigaSpeech dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model( + params: AttributeDict, + enable_giga: bool = True, +) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + if enable_giga: + logging.info("Use giga") + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + else: + logging.info("Disable giga") + decoder_giga = None + joiner_giga = None + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: 
AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. 
+ """ + return c.supervisions[0].custom is None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + libri = is_libri(supervisions["cut"][0]) + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + libri=libri, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + giga_train_dl: + Dataloader for the GigaSpeech training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + For selecting which dataset to use. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + libri_tot_loss = MetricsTracker() + giga_tot_loss = MetricsTracker() + tot_loss = MetricsTracker() + + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.giga_prob, params.giga_prob] + + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + name = "libri" if idx == 0 else "giga" + logging.info(f"{name} reaches end of dataloader") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + libri = is_libri(batch["supervisions"]["cut"][0]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + if libri: + libri_tot_loss = ( + libri_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "libri" # for logging only + else: + giga_tot_loss = ( + giga_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "giga" + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"libri_tot_loss[{libri_tot_loss}], " + f"giga_tot_loss[{giga_tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, + f"train/current_{prefix}_", + params.batch_idx_train, + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + libri_tot_loss.write_summary( + tb_writer, "train/libri_tot_", params.batch_idx_train + ) + giga_tot_loss.write_summary( + tb_writer, "train/giga_tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=True) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + train_cuts = filter_short_and_long_utterances(train_cuts) + + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + if params.full_libri: + logging.info("Using the XL subset of GigaSpeech (10k hours)") + train_giga_cuts = gigaspeech.train_XL_cuts() + else: + logging.info("Using the S subset of GigaSpeech (250 hours)") + train_giga_cuts = gigaspeech.train_S_cuts() + + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = train_giga_cuts.repeat(times=None) + + if args.enable_musan: + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) + else: + cuts_musan = None + + asr_datamodule = AsrDataModule(args) + + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = 
asr_datamodule.train_dataloaders( + train_giga_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + giga_train_dl=giga_train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + assert 0 <= args.giga_prob < 1, args.giga_prob + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file From c8ce243255d7f18dc4485c3367ef470234670e92 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Tue, 15 Nov 2022 22:29:45 -0500 Subject: [PATCH 031/120] Zipformer output length (#686) * add assertion for output length * add comment in filter_cuts * add length filter to Zipformer recipes --- egs/librispeech/ASR/local/filter_cuts.py | 3 + .../ASR/pruned_transducer_stateless7/train.py | 119 ++++++++++++------ .../pruned_transducer_stateless7/zipformer.py | 1 + .../ASR/pruned_transducer_stateless8/train.py | 44 +++++-- 4 files changed, 116 insertions(+), 51 deletions(-) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index 53dbb8211..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -101,6 +101,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): # Note: for ./lstm_transducer_stateless/lstm.py, the formula is # T = ((num_frames - 3) // 2 - 1) // 2 + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) if T < len(tokens): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 8927be227..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -59,7 +59,6 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from zipformer import Zipformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -71,6 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -79,9 +79,9 @@ from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) -from icefall.hooks import register_inf_check_hooks from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ @@ -89,14 +89,12 @@ LRSchedulerType = Union[ ] -def set_batch_count( - 
model: Union[nn.Module, DDP], batch_count: float -) -> None: +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module model = model.module for module in model.modules(): - if hasattr(module, 'batch_count'): + if hasattr(module, "batch_count"): module.batch_count = batch_count @@ -126,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -134,7 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): type=str, default="192,192,192,192,192", help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""" + not the same as embedding dimension.""", ) parser.add_argument( @@ -143,7 +141,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="256,256,256,256,256", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse." + " worse.", ) parser.add_argument( @@ -248,10 +246,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -451,11 +446,14 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): - return tuple(map(int, s.split(','))) + return tuple(map(int, s.split(","))) + encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), encoder_dims=to_int_tuple(params.encoder_dims), attention_dim=to_int_tuple(params.attention_dims), encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), @@ -479,7 +477,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -496,7 +494,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -682,18 +680,17 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. 
simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = ( - simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss - ) + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -873,12 +870,16 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] @@ -888,8 +889,12 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + - (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -905,12 +910,15 @@ def train_one_epoch( ) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) - - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -921,7 +929,9 @@ def train_one_epoch( ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -997,12 +1007,11 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], - find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), - lr=params.base_lr, - clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1043,7 +1052,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1071,8 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, - init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1193,7 +1228,9 @@ def scan_pessimistic_batches_for_oom( ) display_and_save_batch(batch, params=params, sp=sp) raise - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index c14066d38..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1828,6 +1828,7 @@ def _test_zipformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) f[0].sum().backward() c.eval() f = c( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index b4177d3f0..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -90,12 +90,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler @@ -1045,7 +1040,9 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: +def filter_short_and_long_utterances( + cuts: CutSet, sp: spm.SentencePieceProcessor +) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -1055,7 +1052,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1162,7 +1186,7 @@ def run(rank, world_size, args): train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) + train_cuts = filter_short_and_long_utterances(train_cuts, sp) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours @@ -1179,7 +1203,7 @@ def run(rank, world_size, args): logging.info("Using the S subset of GigaSpeech (250 hours)") train_giga_cuts = gigaspeech.train_S_cuts() - train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: From aa7bae1ecd520e06804721c3b13a8c3c2eb06bcc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 16 Nov 2022 19:58:28 +0800 Subject: [PATCH 032/120] fix decode.py for conformer_ctc in gigaspeech (#688) --- egs/gigaspeech/ASR/conformer_ctc/decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 51406667e..9c1418baa 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -481,9 +481,9 @@ def decode_dataset( ), "It should not decode to empty in the first batch!" 
this_batch = [] hyp_words = [] - for ref_text in texts: + for cut_id, ref_text in zip(cut_ids, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) for lm_scale in results.keys(): results[lm_scale].extend(this_batch) From d110b04ad389134c82fa314e3aafc7b40043efb0 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Wed, 16 Nov 2022 13:06:43 -0500 Subject: [PATCH 033/120] apply new black formatting to all files --- .github/workflows/style_check.yml | 11 +- .pre-commit-config.yaml | 26 +- docker/README.md | 24 +- .../Dockerfile | 14 +- .../Dockerfile | 17 +- .../images/k2-gt-v1.9-blueviolet.svg | 2 +- .../images/python-gt-v3.6-blue.svg | 2 +- .../images/torch-gt-v1.6.0-green.svg | 2 +- docs/source/recipes/aishell/index.rst | 1 - docs/source/recipes/timit/index.rst | 1 - docs/source/recipes/timit/tdnn_ligru_ctc.rst | 28 +- docs/source/recipes/timit/tdnn_lstm_ctc.rst | 24 +- .../local/compute_fbank_aidatatang_200zh.py | 8 +- .../ASR/local/prepare_char.py | 8 +- .../ASR/local/prepare_lang.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- egs/aidatatang_200zh/ASR/local/text2token.py | 21 +- egs/aidatatang_200zh/ASR/prepare.sh | 3 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/aishell/ASR/conformer_ctc/conformer.py | 70 +- egs/aishell/ASR/conformer_ctc/decode.py | 29 +- egs/aishell/ASR/conformer_ctc/export.py | 17 +- egs/aishell/ASR/conformer_ctc/pretrained.py | 39 +- egs/aishell/ASR/conformer_ctc/subsampling.py | 16 +- .../ASR/conformer_ctc/test_subsampling.py | 3 +- egs/aishell/ASR/conformer_ctc/train.py | 12 +- egs/aishell/ASR/conformer_ctc/transformer.py | 44 +- egs/aishell/ASR/conformer_mmi/conformer.py | 70 +- egs/aishell/ASR/conformer_mmi/decode.py | 33 +- egs/aishell/ASR/conformer_mmi/subsampling.py | 16 +- egs/aishell/ASR/conformer_mmi/train.py | 8 +- egs/aishell/ASR/conformer_mmi/transformer.py | 44 +- .../local/compute_fbank_aidatatang_200zh.py | 8 +- .../ASR/local/compute_fbank_aishell.py | 8 +- egs/aishell/ASR/local/prepare_char.py | 8 +- egs/aishell/ASR/local/prepare_lang.py | 4 +- egs/aishell/ASR/local/test_prepare_lang.py | 4 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/export.py | 31 +- .../pretrained.py | 50 +- .../ASR/pruned_transducer_stateless2/train.py | 64 +- .../pruned_transducer_stateless3/decode.py | 73 +- .../pruned_transducer_stateless3/export.py | 54 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../pretrained.py | 50 +- .../ASR/pruned_transducer_stateless3/train.py | 79 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 118 +- egs/aishell/ASR/tdnn_lstm_ctc/decode.py | 33 +- egs/aishell/ASR/tdnn_lstm_ctc/model.py | 5 +- egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 37 +- egs/aishell/ASR/tdnn_lstm_ctc/train.py | 7 +- .../ASR/transducer_stateless/beam_search.py | 22 +- .../ASR/transducer_stateless/conformer.py | 70 +- .../ASR/transducer_stateless/decode.py | 39 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- egs/aishell/ASR/transducer_stateless/model.py | 4 +- .../ASR/transducer_stateless/pretrained.py | 36 +- egs/aishell/ASR/transducer_stateless/train.py | 15 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../asr_datamodule.py | 85 +- .../transducer_stateless_modified-2/decode.py | 46 +- 
.../transducer_stateless_modified-2/export.py | 20 +- .../pretrained.py | 50 +- .../transducer_stateless_modified-2/train.py | 22 +- .../transducer_stateless_modified/decode.py | 46 +- .../transducer_stateless_modified/export.py | 20 +- .../pretrained.py | 50 +- .../transducer_stateless_modified/train.py | 15 +- egs/aishell2/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_aishell2.py | 8 +- .../pruned_transducer_stateless5/__init__.py | 0 .../asr_datamodule.py | 114 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- .../ASR/local/compute_fbank_aishell4.py | 8 +- egs/aishell4/ASR/local/prepare_char.py | 8 +- egs/aishell4/ASR/local/prepare_lang.py | 4 +- egs/aishell4/ASR/local/test_prepare_lang.py | 4 +- egs/aishell4/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 69 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 45 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_alimeeting.py | 8 +- egs/alimeeting/ASR/local/prepare_char.py | 8 +- egs/alimeeting/ASR/local/prepare_lang.py | 4 +- egs/alimeeting/ASR/local/test_prepare_lang.py | 4 +- egs/alimeeting/ASR/local/text2segments.py | 2 +- egs/alimeeting/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 60 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/csj/ASR/.gitignore | 2 +- egs/csj/ASR/local/compute_fbank_csj.py | 38 +- egs/csj/ASR/local/compute_fbank_musan.py | 17 +- egs/csj/ASR/local/conf/disfluent.ini | 55 +- egs/csj/ASR/local/conf/fluent.ini | 55 +- egs/csj/ASR/local/conf/number.ini | 55 +- egs/csj/ASR/local/conf/symbol.ini | 55 +- .../ASR/local/display_manifest_statistics.py | 4 +- egs/csj/ASR/local/prepare_lang_char.py | 17 +- egs/csj/ASR/local/validate_manifest.py | 7 +- .../ASR/conformer_ctc/asr_datamodule.py | 117 +- egs/gigaspeech/ASR/conformer_ctc/conformer.py | 66 +- egs/gigaspeech/ASR/conformer_ctc/decode.py | 29 +- .../ASR/conformer_ctc/gigaspeech_scoring.py | 3 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/gigaspeech/ASR/conformer_ctc/train.py | 12 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/preprocess_gigaspeech.py | 10 +- .../asr_datamodule.py | 117 +- .../pruned_transducer_stateless2/decode.py | 42 +- .../pruned_transducer_stateless2/export.py | 24 +- .../ASR/pruned_transducer_stateless2/train.py | 48 +- egs/librispeech/ASR/conformer_ctc/ali.py | 25 +- .../ASR/conformer_ctc/conformer.py | 66 +- egs/librispeech/ASR/conformer_ctc/decode.py | 29 +- egs/librispeech/ASR/conformer_ctc/export.py | 17 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/pretrained.py | 33 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/librispeech/ASR/conformer_ctc/train.py | 22 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../ASR/conformer_ctc2/attention.py | 19 +- .../ASR/conformer_ctc2/conformer.py | 65 +- egs/librispeech/ASR/conformer_ctc2/decode.py | 56 +- egs/librispeech/ASR/conformer_ctc2/export.py | 49 +- egs/librispeech/ASR/conformer_ctc2/train.py | 39 +- .../ASR/conformer_ctc2/transformer.py | 50 +- .../ASR/conformer_mmi/conformer.py | 70 +- 
egs/librispeech/ASR/conformer_mmi/decode.py | 29 +- .../ASR/conformer_mmi/subsampling.py | 16 +- .../ASR/conformer_mmi/test_subsampling.py | 3 +- .../ASR/conformer_mmi/test_transformer.py | 9 +- .../ASR/conformer_mmi/train-with-attention.py | 27 +- egs/librispeech/ASR/conformer_mmi/train.py | 27 +- .../ASR/conformer_mmi/transformer.py | 28 +- .../decode.py | 69 +- .../emformer.py | 119 +- .../export.py | 47 +- .../stream.py | 8 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../decode.py | 69 +- .../emformer.py | 108 +- .../export.py | 47 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../ASR/local/add_alignment_librispeech.py | 12 +- egs/librispeech/ASR/local/compile_hlg.py | 4 +- egs/librispeech/ASR/local/compile_lg.py | 4 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/compute_fbank_librispeech.py | 8 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../convert_transcript_words_to_tokens.py | 16 +- egs/librispeech/ASR/local/download_lm.py | 4 +- egs/librispeech/ASR/local/filter_cuts.py | 10 +- .../ASR/local/generate_unique_lexicon.py | 4 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 4 +- .../ASR/local/prepare_lm_training_data.py | 11 +- .../ASR/local/preprocess_gigaspeech.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- .../ASR/local/validate_manifest.py | 7 +- .../ASR/lstm_transducer_stateless/decode.py | 818 ------------ .../ASR/lstm_transducer_stateless/export.py | 388 ------ .../jit_pretrained.py | 322 ----- .../ASR/lstm_transducer_stateless/lstm.py | 871 ------------- .../ASR/lstm_transducer_stateless/model.py | 210 --- .../lstm_transducer_stateless/pretrained.py | 352 ----- .../ASR/lstm_transducer_stateless/stream.py | 148 --- .../streaming_decode.py | 968 -------------- .../ASR/lstm_transducer_stateless/train.py | 1157 ----------------- .../ASR/lstm_transducer_stateless2/decode.py | 67 +- .../ASR/lstm_transducer_stateless2/export.py | 59 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless2/model.py | 8 +- .../lstm_transducer_stateless2/ncnn-decode.py | 15 +- .../lstm_transducer_stateless2/pretrained.py | 40 +- .../streaming-ncnn-decode.py | 27 +- .../streaming-onnx-decode.py | 45 +- .../ASR/lstm_transducer_stateless2/train.py | 68 +- .../ASR/lstm_transducer_stateless3/decode.py | 79 +- .../ASR/lstm_transducer_stateless3/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless3/lstm.py | 14 +- .../lstm_transducer_stateless3/pretrained.py | 40 +- .../streaming_decode.py | 74 +- .../ASR/lstm_transducer_stateless3/train.py | 66 +- .../ASR/pruned2_knowledge/asr_datamodule.py | 125 +- .../ASR/pruned2_knowledge/beam_search.py | 18 +- .../ASR/pruned2_knowledge/conformer.py | 90 +- .../ASR/pruned2_knowledge/decode.py | 44 +- .../ASR/pruned2_knowledge/decoder.py | 4 +- .../ASR/pruned2_knowledge/decoder2.py | 84 +- .../ASR/pruned2_knowledge/export.py | 20 +- .../ASR/pruned2_knowledge/joiner.py | 4 +- .../ASR/pruned2_knowledge/model.py | 8 +- .../ASR/pruned2_knowledge/optim.py | 35 +- .../ASR/pruned2_knowledge/sampling.py | 184 +-- .../ASR/pruned2_knowledge/scaling.py | 51 +- .../ASR/pruned2_knowledge/scaling_tmp.py | 355 +++-- .../ASR/pruned2_knowledge/train.py | 50 +- .../pruned_stateless_emformer_rnnt2/decode.py | 69 +- .../emformer.py | 8 +- .../pruned_stateless_emformer_rnnt2/export.py | 47 +- .../pruned_stateless_emformer_rnnt2/model.py | 4 +- .../pruned_stateless_emformer_rnnt2/train.py | 44 +- .../beam_search.py | 26 +- 
.../ASR/pruned_transducer_stateless/decode.py | 44 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless/decoder.py | 4 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../ASR/pruned_transducer_stateless/model.py | 4 +- .../pruned_transducer_stateless/pretrained.py | 36 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless/train.py | 46 +- .../beam_search.py | 51 +- .../pruned_transducer_stateless2/conformer.py | 97 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/decoder.py | 8 +- .../pruned_transducer_stateless2/export.py | 24 +- .../pruned_transducer_stateless2/joiner.py | 4 +- .../ASR/pruned_transducer_stateless2/model.py | 8 +- .../ASR/pruned_transducer_stateless2/optim.py | 35 +- .../pretrained.py | 36 +- .../pruned_transducer_stateless2/scaling.py | 56 +- .../streaming_beam_search.py | 12 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless2/train.py | 58 +- .../asr_datamodule.py | 85 +- .../decode-giga.py | 54 +- .../pruned_transducer_stateless3/decode.py | 74 +- .../pruned_transducer_stateless3/export.py | 32 +- .../gigaspeech.py | 8 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 36 +- .../scaling_converter.py | 10 +- .../streaming_decode.py | 39 +- .../pruned_transducer_stateless3/test_onnx.py | 24 +- .../ASR/pruned_transducer_stateless3/train.py | 65 +- .../pruned_transducer_stateless4/decode.py | 79 +- .../pruned_transducer_stateless4/export.py | 47 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless4/train.py | 61 +- .../pruned_transducer_stateless5/conformer.py | 118 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 66 +- .../pruned_transducer_stateless6/conformer.py | 67 +- .../pruned_transducer_stateless6/decode.py | 69 +- .../pruned_transducer_stateless6/export.py | 24 +- .../extract_codebook_index.py | 3 +- .../hubert_decode.py | 17 +- .../hubert_xlarge.py | 22 +- .../ASR/pruned_transducer_stateless6/model.py | 12 +- .../ASR/pruned_transducer_stateless6/train.py | 65 +- .../pruned_transducer_stateless6/vq_utils.py | 31 +- .../pruned_transducer_stateless7/decode.py | 67 +- .../pruned_transducer_stateless7/decoder.py | 6 +- .../pruned_transducer_stateless7/export.py | 47 +- .../jit_pretrained.py | 21 +- .../pruned_transducer_stateless7/joiner.py | 4 +- .../ASR/pruned_transducer_stateless7/model.py | 16 +- .../ASR/pruned_transducer_stateless7/optim.py | 439 ++++--- .../pretrained.py | 40 +- .../pruned_transducer_stateless7/scaling.py | 487 +++---- .../scaling_converter.py | 12 +- .../ASR/pruned_transducer_stateless7/train.py | 88 +- .../pruned_transducer_stateless7/zipformer.py | 660 +++++----- .../pruned_transducer_stateless8/decode.py | 67 +- .../pruned_transducer_stateless8/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless8/model.py | 4 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless8/train.py | 99 +- .../ASR/streaming_conformer_ctc/README.md | 16 +- .../ASR/streaming_conformer_ctc/conformer.py | 116 +- .../streaming_decode.py | 68 +- .../ASR/streaming_conformer_ctc/train.py | 16 +- .../streaming_conformer_ctc/transformer.py | 40 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 113 +- 
egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/librispeech/ASR/tdnn_lstm_ctc/model.py | 5 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 8 +- egs/librispeech/ASR/transducer/beam_search.py | 14 +- egs/librispeech/ASR/transducer/decode.py | 28 +- egs/librispeech/ASR/transducer/export.py | 17 +- egs/librispeech/ASR/transducer/pretrained.py | 33 +- egs/librispeech/ASR/transducer/rnn.py | 24 +- egs/librispeech/ASR/transducer/test_rnn.py | 16 +- egs/librispeech/ASR/transducer/train.py | 12 +- .../ASR/transducer_lstm/beam_search.py | 14 +- egs/librispeech/ASR/transducer_lstm/decode.py | 28 +- .../ASR/transducer_lstm/encoder.py | 4 +- egs/librispeech/ASR/transducer_lstm/train.py | 12 +- .../ASR/transducer_stateless/alignment.py | 4 +- .../ASR/transducer_stateless/beam_search.py | 28 +- .../ASR/transducer_stateless/compute_ali.py | 24 +- .../ASR/transducer_stateless/conformer.py | 107 +- .../ASR/transducer_stateless/decode.py | 42 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/joiner.py | 8 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../transducer_stateless/test_compute_ali.py | 11 +- .../transducer_stateless/test_conformer.py | 4 +- .../ASR/transducer_stateless/train.py | 23 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../ASR/transducer_stateless2/decode.py | 42 +- .../ASR/transducer_stateless2/export.py | 20 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../ASR/transducer_stateless2/train.py | 23 +- .../decode.py | 42 +- .../export.py | 20 +- .../pretrained.py | 36 +- .../test_asr_datamodule.py | 4 +- .../train.py | 22 +- egs/ptb/LM/local/sort_lm_training_data.py | 4 +- .../LM/local/test_prepare_lm_training_data.py | 4 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../ASR/local/compute_fbank_spgispeech.py | 14 +- egs/spgispeech/ASR/local/prepare_splits.py | 8 +- .../asr_datamodule.py | 100 +- .../pruned_transducer_stateless2/decode.py | 66 +- .../pruned_transducer_stateless2/export.py | 26 +- .../ASR/pruned_transducer_stateless2/train.py | 51 +- .../ASR/local/compute_fbank_tal_csasr.py | 8 +- egs/tal_csasr/ASR/local/prepare_char.py | 4 +- egs/tal_csasr/ASR/local/prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/test_prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 77 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_tedlium.py | 8 +- .../convert_transcript_words_to_bpe_ids.py | 4 +- egs/tedlium3/ASR/local/prepare_lexicon.py | 11 +- egs/tedlium3/ASR/local/prepare_transcripts.py | 11 +- .../ASR/pruned_transducer_stateless/decode.py | 38 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../pruned_transducer_stateless/pretrained.py | 41 +- .../ASR/pruned_transducer_stateless/train.py | 35 +- .../transducer_stateless/asr_datamodule.py | 118 +- .../ASR/transducer_stateless/beam_search.py | 30 +- .../ASR/transducer_stateless/decode.py | 31 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless/train.py | 11 +- egs/timit/ASR/RESULTS.md | 2 +- egs/timit/ASR/local/compile_hlg.py | 4 +- egs/timit/ASR/local/compute_fbank_timit.py | 8 +- egs/timit/ASR/local/prepare_lexicon.py | 8 +- egs/timit/ASR/prepare.sh | 4 +- 
egs/timit/ASR/tdnn_ligru_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_ligru_ctc/model.py | 12 +- egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_ligru_ctc/train.py | 4 +- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 104 +- egs/timit/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_lstm_ctc/model.py | 5 +- egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_lstm_ctc/train.py | 4 +- .../compute_fbank_wenetspeech_dev_test.py | 11 +- .../local/compute_fbank_wenetspeech_splits.py | 10 +- egs/wenetspeech/ASR/local/prepare_char.py | 8 +- .../ASR/local/preprocess_wenetspeech.py | 6 +- egs/wenetspeech/ASR/local/text2token.py | 21 +- egs/wenetspeech/ASR/prepare.sh | 2 +- .../asr_datamodule.py | 121 +- .../pruned_transducer_stateless2/decode.py | 64 +- .../pruned_transducer_stateless2/export.py | 28 +- .../jit_pretrained.py | 21 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- .../pruned_transducer_stateless5/conformer.py | 97 +- .../pruned_transducer_stateless5/decode.py | 75 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless5/export.py | 20 +- .../pretrained.py | 41 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- egs/yesno/ASR/local/compile_hlg.py | 4 +- egs/yesno/ASR/local/compute_fbank_yesno.py | 12 +- egs/yesno/ASR/tdnn/asr_datamodule.py | 74 +- egs/yesno/ASR/tdnn/decode.py | 29 +- egs/yesno/ASR/tdnn/pretrained.py | 37 +- egs/yesno/ASR/tdnn/train.py | 4 +- egs/yesno/ASR/transducer/decode.py | 25 +- egs/yesno/ASR/transducer/train.py | 4 +- icefall/char_graph_compiler.py | 8 +- icefall/checkpoint.py | 12 +- icefall/decode.py | 36 +- icefall/diagnostics.py | 80 +- icefall/dist.py | 4 +- icefall/env.py | 4 +- icefall/graph_compiler.py | 4 +- icefall/hooks.py | 19 +- icefall/lexicon.py | 16 +- icefall/mmi.py | 29 +- icefall/mmi_graph_compiler.py | 8 +- icefall/rnn_lm/compute_perplexity.py | 15 +- icefall/rnn_lm/dataset.py | 8 +- icefall/rnn_lm/export.py | 17 +- icefall/rnn_lm/model.py | 28 +- icefall/rnn_lm/train.py | 11 +- icefall/shared/make_kn_lm.py | 184 ++- icefall/utils.py | 64 +- pyproject.toml | 2 +- setup.py | 3 +- test/test_checkpoint.py | 6 +- test/test_decode.py | 1 + test/test_graph_compiler.py | 4 +- test/test_utils.py | 4 +- 440 files changed, 6789 insertions(+), 14532 deletions(-) mode change 100755 => 100644 egs/aishell2/ASR/local/__init__.py mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/decode.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/export.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/train.py diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 90459bc1c..45d261ccc 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,17 +45,18 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 
flake8==3.9.2 click==8.0.4 - # See https://github.com/psf/black/issues/2964 - # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 + python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 + # Click issue fixed in https://github.com/psf/black/pull/2966 - name: Run flake8 shell: bash working-directory: ${{github.workspace}} run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --show-source --statistics - flake8 . + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 - name: Run black shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 446ba0fe7..e2055801b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,38 @@ repos: - repo: https://github.com/psf/black - rev: 21.6b0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=80] + args: ["--line-length=88"] additional_dependencies: ['click==8.0.1'] exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 5.0.4 hooks: - id: flake8 - args: [--max-line-length=80] + args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] + + # What are we ignoring here? + # E203: whitespace before ':' + # E266: too many leading '#' for block comment + # E501: line too long + # F401: module imported but unused + # E402: module level import not at top of file + # F403: 'from module import *' used; unable to detect undefined names + # F841: local variable is assigned to but never used + # W503: line break before binary operator + # In addition, the default ignore list is: + # E121,E123,E126,E226,E24,E704,W503,W504 - repo: https://github.com/pycqa/isort - rev: 5.9.2 + rev: 5.10.1 hooks: - id: isort - args: [--profile=black, --line-length=80] + args: ["--profile=black"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: check-executables-have-shebangs - id: end-of-file-fixer diff --git a/docker/README.md b/docker/README.md index 6f2314e96..c14b9bf75 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,7 +2,7 @@ 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. -If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. +If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0. 
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with ```bash $ nvidia-smi -Tue Sep 20 00:26:13 2022 +Tue Sep 20 00:26:13 2022 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 | |-------------------------------+----------------------+----------------------+ @@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022 | 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ - + +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | @@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022 ``` ## Building images locally -If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. -For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. +If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. +For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. ```dockerfile ENV http_proxy=http://aaa.bb.cc.net:8080 \ https_proxy=http://aaa.bb.cc.net:8080 ``` -Then, proceed with these commands. +Then, proceed with these commands. ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3: @@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. -Overall, your docker run command should look like this. +Overall, your docker run command should look like this. ```bash docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 @@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re ### Linking to icefall in your host machine -If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. +If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. -Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. +Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below. Use these commands once you are inside the container. 
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall docker exec -it icefall /bin/bash ``` -## Restarting a killed container that has been run before. +## Restarting a killed container that has been run before. ```bash docker start -ai icefall ``` @@ -111,4 +111,4 @@ docker start -ai icefall ## Sample usage of the CPU based images: ```bash docker run -it icefall /bin/bash -``` +``` diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index 3637d2f11..ff9e40604 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 # install normal source RUN apt-get update && \ @@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.12 && \ pip install graphviz - + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ @@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt -RUN pip install kaldifeat +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 17a8215f9..5c7423fa5 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -1,12 +1,12 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 RUN rm /etc/apt/sources.list.d/cuda.list && \ rm /etc/apt/sources.list.d/nvidia-ml.list && \ apt-key del 7fa2af80 - + # install normal source RUN apt-get update && \ apt-get install -y --no-install-recommends \ @@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18 curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ - rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/lib/apt/lists/* && \ mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \ mv /opt/conda/lib/libcublas.so.11 
/opt/libcublas.so.11.bak && \ @@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.7.1 && \ pip install graphviz @@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ cd - # install lhotse -RUN pip install git+https://github.com/lhotse-speech/lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ @@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall - diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg index 534b2e534..3019ff03d 100644 --- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg +++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg @@ -1 +1 @@ -k2: >= v1.9k2>= v1.9 \ No newline at end of file +k2: >= v1.9k2>= v1.9 diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg index 4254dc58a..df677ad09 100644 --- a/docs/source/installation/images/python-gt-v3.6-blue.svg +++ b/docs/source/installation/images/python-gt-v3.6-blue.svg @@ -1 +1 @@ -python: >= 3.6python>= 3.6 \ No newline at end of file +python: >= 3.6python>= 3.6 diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg index d3ece9a17..d7007d742 100644 --- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg +++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg @@ -1 +1 @@ -torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file +torch: >= 1.6.0torch>= 1.6.0 diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst index d072d6e9c..b77d59bca 100644 --- a/docs/source/recipes/aishell/index.rst +++ b/docs/source/recipes/aishell/index.rst @@ -19,4 +19,3 @@ It can be downloaded from ``_ tdnn_lstm_ctc conformer_ctc stateless_transducer - diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst index 17f40cdb7..5ee147be7 100644 --- a/docs/source/recipes/timit/index.rst +++ b/docs/source/recipes/timit/index.rst @@ -6,4 +6,3 @@ TIMIT tdnn_ligru_ctc tdnn_lstm_ctc - diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst index 186420ee7..3d7aefe02 100644 --- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst +++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst @@ -148,10 +148,10 @@ Some commonly used options are: $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17 - uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, - ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, - ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, 
``epoch-19.pt``, - ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, + uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, + ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, + ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, + ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_ligru_ctc/pretrained.py + ./tdnn_ligru_ctc/pretrained.py --method 1best - --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt - --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt - --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt + --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt + --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -337,7 +337,7 @@ The output is: 2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started 2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 20:41:39,829 INFO [pretrained.py:267] + 2021-11-08 20:41:39,829 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh @@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \ --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.1 \ - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -378,7 +378,7 @@ The decoding output is: 2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started 2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:37:56,348 INFO [pretrained.py:267] + 2021-11-08 20:37:56,348 INFO [pretrained.py:267] 
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst index 6f760a9ce..ee67a6edc 100644 --- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst @@ -148,8 +148,8 @@ Some commonly used options are: $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10 - uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, - ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, + uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, + ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_lstm_ctc/pretrained.py + ./tdnn_lstm_ctc/pretrained.py --method 1best - --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt - --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt - --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt + --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt + --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -335,7 +335,7 @@ The output is: 2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started 2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 21:02:54,387 INFO [pretrained.py:267] + 2021-11-08 21:02:54,387 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh @@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \ --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.08 \ - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -376,7 +376,7 @@ The decoding output is: 2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', 
'./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started 2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:05:27,878 INFO [pretrained.py:267] + 2021-11-08 20:05:27,878 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index fb2751c0f..387c14acf 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_char.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ b/egs/aidatatang_200zh/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 039951354..4749e1b7f 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! 
-f $lang_char_dir/words.txt ]; then ./local/prepare_words.py \ --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt + --output-file $lang_char_dir/words.txt fi if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py fi fi - diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 6a5b57e24..8c94f5bea 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,10 +81,12 @@ class Aidatatang_200zhAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -96,75 +98,91 @@ class Aidatatang_200zhAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +196,22 @@ class Aidatatang_200zhAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -205,24 +227,20 @@ class Aidatatang_200zhAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -237,9 +255,7 @@ class Aidatatang_200zhAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +298,7 @@ class Aidatatang_200zhAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +354,7 @@ class Aidatatang_200zhAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index f0407f429..3f582ef04 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -69,11 +69,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -92,25 +88,30 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--batch", type=int, default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -192,8 +193,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +264,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -315,11 +310,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -390,9 +381,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -425,8 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py index 00b54c39f..34f4d3ddf 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -62,17 +62,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -103,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index eb5e6b0d4..3c96ed07b 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,9 +85,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -112,10 +114,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -162,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +196,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +259,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +284,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +336,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index d46838b68..c7b1a4266 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,42 +185,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." 
+ ), ) parser.add_argument( @@ -542,22 +543,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +705,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +805,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index cb7205e51..f5b5873b4 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = 
d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 751b7d5b5..a30fa52df 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -58,16 +58,19 @@ def get_parser(): "--epoch", type=int, default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -401,9 +404,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -431,9 +432,7 @@ def save_results( # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -441,9 +440,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -562,9 +559,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py index 42b8c29e7..9ee405e8b 100644 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -40,17 +40,20 @@ def get_parser(): "--epoch", type=int, default=84, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=25, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -157,9 +160,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 27776bc24..e3d5a20e3 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -46,27 +46,29 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( "--tokens-file", type=str, - help="Path to tokens.txt" "Used only when method is ctc-decoding", + help="Path to tokens.txtUsed only when method is ctc-decoding", ) parser.add_argument( "--words-file", type=str, - help="Path to words.txt" "Used when method is NOT ctc-decoding", + help="Path to words.txtUsed when method is NOT ctc-decoding", ) parser.add_argument( "--HLG", type=str, - help="Path to HLG.pt." "Used when method is NOT ctc-decoding", + help="Path to HLG.pt.Used when method is NOT ctc-decoding", ) parser.add_argument( @@ -163,10 +165,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. 
" + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -210,10 +214,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -274,9 +277,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -371,9 +372,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py index e3361d0c9..81fa234dd 100755 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. 
-from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index a228cc1fe..c2cbe6e3b 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -382,9 +382,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -520,9 +518,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -630,9 +626,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, 
padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py index cb7205e51..f5b5873b4 100644 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ b/egs/aishell/ASR/conformer_mmi/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 4db367e36..a43183063 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -59,16 +59,19 @@ def get_parser(): "--epoch", type=int, default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
" + ), ) parser.add_argument( @@ -413,9 +416,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -443,9 +444,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -453,9 +452,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -550,9 +547,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -581,9 +576,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py index 720ed6c22..398837a46 100644 --- a/egs/aishell/ASR/conformer_mmi/subsampling.py +++ b/egs/aishell/ASR/conformer_mmi/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 685831d09..09cd6e60c 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -511,9 +511,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +623,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_mmi/transformer.py +++ b/egs/aishell/ASR/conformer_mmi/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 42700a972..037971927 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index deab6c809..115ca1031 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ b/egs/aishell/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in 
range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ b/egs/aishell/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell/ASR/local/test_prepare_lang.py +++ b/egs/aishell/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index a12934d55..ae926ec66 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -118,9 +114,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -188,8 +186,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +246,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -263,10 +258,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -310,11 +302,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -387,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -415,9 +401,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -428,8 +412,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -473,9 +456,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -504,8 +485,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index feababdd2..5f6888db4 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -50,11 +50,7 @@ from pathlib import Path import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import 
average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -157,8 +154,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -191,9 +187,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -201,17 +195,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index 3c38e5db7..f754a7b9e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) add_model_arguments(parser) @@ -196,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -256,13 +260,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -310,9 +310,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -329,9 +327,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 97d892754..66ca23035 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -49,7 +49,6 @@ import optim import torch import torch.multiprocessing as mp import torch.nn as nn - from asr_datamodule import AishellAsrDataModule from conformer import Conformer from decoder import Decoder @@ -75,9 +74,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -203,8 +200,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -227,42 +223,45 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -561,11 +560,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -593,23 +588,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -725,9 +713,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -891,7 +877,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1029,9 +1015,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index d159e420b..6c505940d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -121,20 +121,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -202,8 +206,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -263,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -277,10 +278,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,11 +322,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -429,9 +421,7 @@ def save_results( # we compute CER for aishell dataset. 
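The character-level scoring set up by the lines that follow works by joining each word list into one string and re-splitting it into characters, so write_error_stats ends up measuring edit distance over characters, i.e. CER. A toy illustration with made-up strings (not recipe data):

    ref_words = ["今天", "天气", "好"]     # toy reference, already word-segmented
    hyp_words = ["今天", "天汽", "好"]     # toy hypothesis with one wrong character
    ref_chars = list("".join(ref_words))   # ['今', '天', '天', '气', '好']
    hyp_chars = list("".join(hyp_words))   # ['今', '天', '天', '汽', '好']
    # One substituted character out of five gives a 20% CER for this pair.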
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -442,8 +432,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -488,9 +477,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -518,13 +505,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -551,13 +537,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -586,7 +571,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 566902a85..e5a5d7c77 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -88,20 +88,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -132,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -166,13 +169,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -195,13 +197,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -229,7 +230,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -252,9 +253,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -262,17 +261,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index e150e8230..a4dda0d6d 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_datatang = decoder_datatang self.joiner_datatang = joiner_datatang - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 
- ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_datatang is not None: @@ -179,9 +177,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index 04a0a882a..109879952 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) add_model_arguments(parser) @@ -196,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,13 +261,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -311,9 +311,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -330,9 +328,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index feaef5cf6..b24f533ff 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -96,9 +96,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -224,8 +222,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -248,42 +245,45 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -635,11 +635,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -670,23 +666,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -824,9 +813,7 @@ def train_one_epoch( ) # summary stats if datatang_train_dl is not None: - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info if aishell: aishell_tot_loss = ( @@ -847,9 +834,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -892,9 +877,7 @@ def train_one_epoch( cur_lr = scheduler.get_last_lr()[0] if datatang_train_dl is not None: datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " - tot_loss_str = ( - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - ) + tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " else: tot_loss_str = "" datatang_str = "" @@ -1067,7 +1050,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1076,9 +1059,7 @@ def run(rank, world_size, args): train_cuts = filter_short_and_long_utterances(train_cuts) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1093,9 +1074,7 @@ def run(rank, world_size, args): if params.datatang_prob > 0: datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances( - train_datatang_cuts - ) + train_datatang_cuts = 
filter_short_and_long_utterances(train_datatang_cuts) train_datatang_cuts = train_datatang_cuts.repeat(times=None) datatang_train_dl = asr_datamodule.train_dataloaders( train_datatang_cuts, @@ -1249,9 +1228,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index d24ba6bb7..12ae6e7d4 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -64,10 +64,12 @@ class AishellAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -79,59 +81,74 @@ class AishellAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--drop-last", @@ -143,17 +160,18 @@ class AishellAsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -167,40 +185,40 @@ class AishellAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -215,9 +233,7 @@ class AishellAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -260,9 +276,7 @@ class AishellAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -308,9 +322,7 @@ class AishellAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -366,13 +378,9 @@ class AishellAsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_test.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 66b734fc4..8ef247438 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -49,16 +49,19 @@ def get_parser(): "--epoch", type=int, default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--method", @@ -265,9 +268,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -289,9 +290,7 @@ def save_results( # We compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) test_set_wers[key] = wer @@ -335,9 +334,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -362,9 +359,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") model.to(device) model.eval() @@ -392,9 +387,7 @@ def main(): lexicon=lexicon, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py index 5e04c11b4..1731e1ebe 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py @@ -66,10 +66,7 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] + [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 9bd810809..52f9410cf 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -41,9 +41,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -53,9 +55,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -71,10 +71,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -112,10 +114,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -173,9 +174,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is [N, C, T] with torch.no_grad(): @@ -219,9 +218,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index 7619b0551..e574cf89b 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..de0a8d0f5 100644 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/transducer_stateless/beam_search.py @@ -47,9 +47,9 @@ def greedy_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -81,9 +81,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -157,9 +157,7 @@ class HypothesisList(object): """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -246,9 +244,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index 64114253d..e26c6c385 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = 
nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -701,33 +689,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -764,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -778,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -814,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -843,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 780b0c4bb..1f7bb14e1 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -52,16 +52,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -99,8 +102,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -227,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -248,9 +248,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append([lexicon.token_table[i] for i in hyp]) if params.decoding_method == "greedy_search": @@ -319,9 +317,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -346,9 +342,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -359,8 +353,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -430,9 +423,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index c2c6552a9..70e9e6c96 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -86,9 +86,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 4c6519b96..e35b26fe0 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -69,17 +69,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -110,8 +113,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -243,9 +245,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 994305fc1..591bbe44f 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -103,9 +103,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index db89c4d67..8effc9815 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -73,9 +73,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -100,10 +102,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -117,8 +121,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -211,10 +214,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -273,9 +275,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -319,9 +319,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index d54157709..62ffff473 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -126,8 +126,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -389,9 +388,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -504,9 +501,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +620,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py index e851dcc32..b3ff153c1 100644 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ b/egs/aishell/ASR/transducer_stateless/transformer.py @@ -250,9 +250,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 838e53658..76e209f06 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -29,10 +29,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader 
from icefall.utils import str2bool @@ -46,59 +43,69 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler " - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -112,18 +119,22 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -137,9 +148,11 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet" + ), ) def train_dataloaders( @@ -162,9 +175,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -173,9 +184,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -252,9 +261,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index ea3f94fd8..fd4cb8385 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -93,16 +93,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -170,8 +173,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -227,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -241,10 +241,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -288,11 +285,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -365,9 +358,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -393,9 +384,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -406,8 +395,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -448,9 +436,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index 3bd2ceb11..32481829c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -68,17 +68,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -109,8 +112,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index a95a4bc52..55701a007 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) return parser @@ -194,10 +199,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +258,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +308,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 225d0d709..8fb7d1e49 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -149,8 +149,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -168,8 +167,7 @@ def get_parser(): "--datatang-prob", type=float, default=0.2, - help="The probability to select a batch from the " - "aidatatang_200zh dataset", + help="The probability to select a batch from the aidatatang_200zh dataset", ) return parser @@ -449,9 +447,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -605,9 +601,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) aishell_tot_loss.write_summary( tb_writer, "train/aishell_tot_", params.batch_idx_train ) @@ -735,9 +729,7 @@ def run(rank, world_size, args): train_datatang_cuts = train_datatang_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -776,9 +768,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 65fcda873..1e41942da 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -94,16 +94,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -171,8 +174,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -231,9 +233,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -245,10 +245,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -292,11 +289,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -369,9 +362,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -397,9 +388,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -410,8 +399,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -452,9 +440,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 11335a834..ca1d4bd4a 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -68,17 +68,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -109,8 +112,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 262e822c2..038090461 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) return parser @@ -194,10 +199,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +258,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +308,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index d3ffccafa..5f116f2bd 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -142,8 +142,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -414,9 +413,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -529,9 +526,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -657,9 +652,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index d8d3622bd..ec0c584ca 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py old mode 100755 new mode 100644 index b7a21f579..e8966b554 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -76,10 +76,12 @@ class AiShell2AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -91,59 +93,74 @@ class AiShell2AsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." 
+ ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--drop-last", @@ -155,17 +172,18 @@ class AiShell2AsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -179,18 +197,22 @@ class AiShell2AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. 
" + ), ) group.add_argument( @@ -216,20 +238,16 @@ class AiShell2AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -244,9 +262,7 @@ class AiShell2AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -290,9 +306,7 @@ class AiShell2AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -348,9 +362,7 @@ class AiShell2AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -406,9 +418,7 @@ class AiShell2AsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") - return load_manifest_lazy( - self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> CutSet: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 915737f4a..64b64d1b1 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -168,20 +168,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -269,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -348,9 +351,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -409,10 +410,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -538,9 +536,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -573,8 +569,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -625,9 +620,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -661,13 +654,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -690,13 +682,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -724,7 +715,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -749,9 +740,7 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index bc7bd71cb..547ce2069 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -167,13 +170,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -196,13 +198,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -230,7 +231,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -266,9 +267,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 09de1bece..4b16511e8 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -81,9 +81,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -109,10 +111,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -159,8 +163,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -191,10 +194,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,15 +256,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -334,9 +332,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 838a0497f..d37e7bdca 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -220,8 +218,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -244,42 +241,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." 
+ ), ) parser.add_argument( @@ -603,11 +603,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -636,23 +632,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -771,9 +760,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -829,9 +816,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -939,7 +924,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1104,9 +1089,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 3f50d9e3e..400c406f0 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -120,9 +118,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell4/ASR/local/prepare_char.py +++ b/egs/aishell4/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ b/egs/aishell4/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell4/ASR/local/test_prepare_lang.py +++ b/egs/aishell4/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/aishell4/ASR/local/text2token.py +++ b/egs/aishell4/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 7aa53ddda..84c7f0443 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -74,10 +74,12 @@ class Aishell4AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the 
preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( @@ -91,66 +93,81 @@ class Aishell4AsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( @@ -164,17 +181,18 @@ class Aishell4AsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." 
+ ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -188,18 +206,22 @@ class Aishell4AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -222,24 +244,20 @@ class Aishell4AsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -254,9 +272,7 @@ class Aishell4AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -300,9 +316,7 @@ class Aishell4AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -359,9 +373,7 @@ class Aishell4AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 14e44c7d9..616a88937 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -117,20 +117,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -201,8 +205,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -260,9 +263,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -277,10 +278,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -326,11 +324,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -436,8 +428,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -480,9 +471,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -510,13 
+499,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -543,13 +531,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +565,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index 993341131..3c580ff7b 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -136,8 +140,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -169,13 +172,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -202,13 +204,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -237,7 +238,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -276,9 +277,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index 1fa893637..8151442af 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -94,9 +94,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -122,10 +124,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -172,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,10 +207,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -266,15 +268,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -306,10 +304,7 @@ def main(): for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,9 +345,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 0a48b9059..aacd23ecd 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -85,9 +85,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -213,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -237,42 +234,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -599,11 +599,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -633,22 +629,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -827,9 +816,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -937,7 +924,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index af926aa53..96115a230 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -121,9 +119,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/alimeeting/ASR/local/prepare_char.py +++ b/egs/alimeeting/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] 
- pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ b/egs/alimeeting/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/alimeeting/ASR/local/test_prepare_lang.py +++ b/egs/alimeeting/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py index 7c1019aa8..27b904fc8 100644 --- a/egs/alimeeting/ASR/local/text2segments.py +++ b/egs/alimeeting/ASR/local/text2segments.py @@ -30,8 +30,8 @@ with word segmenting: import argparse -import paddle import jieba +import paddle from tqdm import tqdm paddle.enable_static() diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/alimeeting/ASR/local/text2token.py +++ b/egs/alimeeting/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id 
for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index bf6faad7a..d0467a29e 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,10 +81,12 @@ class AlimeetingAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -96,75 +98,91 @@ class AlimeetingAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +196,22 @@ class AlimeetingAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -205,24 +227,20 @@ class AlimeetingAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -237,9 +255,7 @@ class AlimeetingAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +298,7 @@ class AlimeetingAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -341,9 +355,7 @@ class AlimeetingAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 6358fe970..ffaca1021 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -70,11 +70,7 @@ from beam_search import ( from lhotse.cut import Cut from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -93,25 +89,30 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--batch", type=int, default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -193,8 +194,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +264,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -315,11 +310,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -390,9 +381,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -425,8 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -563,8 +551,7 @@ def main(): ) dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -574,8 +561,7 @@ def main(): ) test_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) ] cuts_test_webdataset = CutSet.from_webdataset( test_shards, @@ -588,9 +574,7 @@ def main(): return 1.0 <= c.duration cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) - cuts_test_webdataset = cuts_test_webdataset.filter( - remove_short_and_long_utt - ) + cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8beec1b8a..482e52d83 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -62,17 +62,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -103,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index 93b1e1f57..afbf0960a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,9 +85,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -112,10 +114,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -162,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +196,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +259,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +284,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +336,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 81a0ede7f..158ea9c1b 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,42 +185,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." 
+ ), ) parser.add_argument( @@ -542,22 +543,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +705,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +805,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore index 5d965832e..cd0e20c4c 100644 --- a/egs/csj/ASR/.gitignore +++ b/egs/csj/ASR/.gitignore @@ -5,4 +5,4 @@ notify_tg.py finetune_* misc.ini .vscode/* -offline/* \ No newline at end of file +offline/* diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 994dedbdd..036ce925f 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -25,15 +25,10 @@ from random import Random from typing import List, Tuple import torch -from lhotse import ( +from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on CutSet, Fbank, FbankConfig, - # fmt: off - # See the following for why LilcomChunkyWriter is preferred - # https://github.com/k2-fsa/icefall/pull/404 - # https://github.com/lhotse-speech/lhotse/pull/527 - # fmt: on LilcomChunkyWriter, RecordingSet, SupervisionSet, @@ -81,17 +76,13 @@ def make_cutset_blueprints( cut_sets.append((f"eval{i}", cut_set)) # Create train and valid cuts - logging.info( - "Loading, trimming, and shuffling the remaining core+noncore cuts." - ) + logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") recording_set = RecordingSet.from_file( manifest_dir / "csj_recordings_core.jsonl.gz" ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") supervision_set = SupervisionSet.from_file( manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file( - manifest_dir / "csj_supervisions_noncore.jsonl.gz" - ) + ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") cut_set = CutSet.from_manifests( recordings=recording_set, @@ -101,15 +92,12 @@ def make_cutset_blueprints( cut_set = cut_set.shuffle(Random(RNG_SEED)) logging.info( - "Creating valid and train cuts from core and noncore," - f"split at {split}." 
+ f"Creating valid and train cuts from core and noncore,split at {split}." ) valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) train_set = CutSet.from_cuts(islice(cut_set, split, None)) - train_set = ( - train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - ) + train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) cut_sets.extend([("valid", valid_set), ("train", train_set)]) @@ -122,15 +110,9 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) - parser.add_argument( - "--split", type=int, default=4000, help="Split at this index" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") + parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument("--split", type=int, default=4000, help="Split at this index") return parser.parse_args() @@ -141,9 +123,7 @@ def main(): extractor = Fbank(FbankConfig(num_mel_bins=80)) num_jobs = min(16, os.cpu_count()) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index 44a33c4eb..f60e62c85 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor - ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. It looks for manifests in the directory data/manifests. @@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -107,21 +104,15 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") + parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") return parser.parse_args() if __name__ == "__main__": args = get_args() - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index eb70673de..c987e72c5 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. 
gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index 5d22f9eb8..f7f27f5bc 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index 2613c3409..cf9038f62 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index 8ba451dd5..f9801284b 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,4 +319,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c9de21073..c043cf853 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,9 +37,7 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to cutset manifests" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index e4d996871..f0078421b 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,8 +68,7 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. " - "If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" ), ) @@ -87,9 +86,7 @@ def main(): args = get_args() logging.basicConfig( - format=( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" - ), + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", level=logging.INFO, ) @@ -111,8 +108,7 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name}" - f" at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." ) for cut in train_set: try: @@ -123,8 +119,7 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in " - f"{cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -143,9 +138,7 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text( - "\n".join(args.userdef_string) - ) + (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 0c4c6c1ea..89448a49c 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,8 +68,7 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" ) @@ -89,9 +88,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index d78e26240..c3e3e84bf 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,10 +61,12 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch 
DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -76,75 +78,91 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." 
+ ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -158,18 +176,22 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), ) # GigaSpeech specific arguments @@ -183,30 +205,25 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -221,9 +238,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -256,9 +271,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -304,9 +317,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -362,9 +373,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 6fac07f93..1153a814c 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ 
class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,33 +696,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -771,9 +751,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +763,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +797,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 9c1418baa..b38ae9c8c 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,16 +62,19 @@ def get_parser(): "--epoch", type=int, default=0, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." 
+ ), ) parser.add_argument( "--avg", type=int, default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -476,9 +479,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -493,9 +494,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -528,9 +527,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -705,9 +702,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index ef53b77f8..880aa76e2 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,8 +73,7 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result via" - "SCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..3b94f0c4b 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,13 +78,10 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, 
out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 2965cde18..4883d04d8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -521,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -641,9 +637,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 8209ee3ec..07beeb1f0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 6410249db..0ee845ec8 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,8 +47,10 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", + help=( + "The maximum number of audio seconds in a batch." + "Determines batch size dynamically." + ), ) parser.add_argument( @@ -134,9 +136,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 48d10a157..31abe7fff 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,19 +98,13 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c87686e1e..9ae3f071e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,10 +73,12 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." 
+ ), ) group.add_argument( "--manifest-dir", @@ -88,75 +90,91 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -170,18 +188,22 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. 
" - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), ) # GigaSpeech specific arguments @@ -195,8 +217,7 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders( @@ -216,20 +237,16 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -244,9 +261,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +304,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +360,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -405,9 +416,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a3471..9f5d4711b 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,11 +77,7 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -118,9 +114,11 @@ def get_parser(): "--avg", type=int, default=8, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -188,8 +186,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +255,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +270,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,11 +316,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -398,9 +386,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -434,8 +420,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,8 +496,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..17f8614dc 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +157,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +205,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae25561..4d1a2356d 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -178,42 +176,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -553,23 +554,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -732,9 +726,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e..0169d0f82 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,16 +61,19 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
" + ), ) parser.add_argument( @@ -231,9 +234,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +259,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +288,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93..1153a814c 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if 
self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,33 +696,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -771,9 +751,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +763,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +797,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3f3b1acda..66fdf82d9 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,16 +64,19 @@ def get_parser(): "--epoch", type=int, default=77, - help="It 
specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=55, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -551,9 +554,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -568,9 +569,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -602,9 +601,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -809,9 +806,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 28c28df01..bdb8a85e5 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,17 +40,20 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
" + ), ) parser.add_argument( @@ -157,9 +160,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index 1f2f3b137..cb0d6e04d 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,13 +82,10 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index a2c0a5486..8cabf1a53 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,9 +48,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -189,10 +191,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -236,10 +240,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -300,9 +303,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -427,9 +428,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 6419f6816..1a1c2f4c5 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,9 +393,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -422,9 +420,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -453,9 +449,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -568,9 +562,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -660,7 +652,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - f"Unsupported type of lang dir (we expected it to have " + "Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -733,9 +725,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 1375d7245..356d3f21b 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ -from scaling import ScaledLinear - class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -76,9 +75,7 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -94,9 +91,7 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear( - embed_dim, 3 * embed_dim, bias=bias - ) + self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -107,12 +102,8 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index b906d2650..a6f1679ef 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,9 +29,8 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from torch import Tensor, nn from subsampling import Conv2dSubsampling - +from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask @@ -182,9 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -373,9 +368,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = 
self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -650,9 +643,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -721,33 +714,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -784,9 +769,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -794,13 +777,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -834,13 +813,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -863,9 +838,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: 
"""Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 97f2f2d39..934177b1f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -130,11 +132,13 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -658,9 +662,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" 
this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +677,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +709,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,13 +850,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -881,13 +878,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -915,7 +911,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -985,9 +981,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc..0e1841d8d 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ import logging from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -89,20 +88,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. 
Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -177,13 +180,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -206,13 +208,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -240,7 +241,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -273,9 +274,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 9d9c2af1f..63534b76b 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -505,11 +503,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -546,9 +540,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -575,9 +567,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -595,9 +585,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -735,8 +723,7 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} " - f"failing batch names {batch_name}" + f"failing batch size:{batch_size} failing batch names {batch_name}" ) raise @@ -791,9 +778,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -806,9 +793,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -900,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - f"Unsupported type of lang dir (we expected it to have " + "Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index fa179acc0..8f0c7dcde 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 +21,17 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. 
Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -659,9 +643,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class TransformerEncoder(nn.Module): @@ -982,9 +964,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -1005,9 +985,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a2..4d9ddaea9 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,33 +690,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -765,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +816,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489..e8390ded9 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,16 +60,19 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
" + ), ) parser.add_argument( @@ -478,9 +481,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +513,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +652,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -687,9 +684,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222e..ad9415987 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d77..d0bb017dd 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e680607..25d18076d 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73..f8c94cff9 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -370,10 +361,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -762,19 +750,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2..5cfb2bfc7 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ 
b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -377,10 +368,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -770,19 +758,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65c..2542d9abe 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a19..a1c43f7f5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,20 +135,24 @@ def get_parser(): "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -215,8 +219,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +302,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,11 +348,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,13 +529,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,13 +557,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -603,7 +590,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " 
f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d5568..0639ba746 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." 
) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. 
""" - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, 
self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea..59105e286 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,9 +103,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -136,19 +138,20 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
" + ), ) add_model_arguments(parser) @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc1..c211b215e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe8658..abe83732a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,8 +113,9 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -131,20 +132,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -211,8 +216,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +375,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +392,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +551,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +601,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +776,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +824,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - 
f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,13 +858,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -896,13 +886,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -930,7 +919,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b..a76417e5f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,42 +263,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -636,11 +637,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +665,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +861,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +969,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5..9cb4a5afc 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,20 +135,24 @@ def get_parser(): "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -215,8 +219,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +302,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,11 +348,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,13 +529,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - 
f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,13 +557,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -603,7 +590,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc7..09200f2e1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R 
+ U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -561,16 +543,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, 
right_context=right_context, memory=pre_memory, @@ -967,9 +935,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1247,9 +1211,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e0241..4d05b367c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,9 +103,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -136,19 +138,20 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) add_model_arguments(parser) @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d..0486ac2eb 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,8 +113,9 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." 
+ ), ) parser.add_argument( @@ -131,20 +132,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -211,8 +216,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +375,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +392,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +551,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +601,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +776,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: 
print("settings\tWER", file=f) @@ -831,9 +824,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,13 +858,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -896,13 +886,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -930,7 +919,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d78..2c2593b56 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,42 +263,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -636,11 +637,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +665,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +861,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +969,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index fe6a26c51..cc34a72d8 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." 
- ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not have alignment." - ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 9a35750e0..295156ed5 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -150,9 +150,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5f..19bf3bff4 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,9 +132,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c5..97750f3ea 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e5..37fce11f4 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,8 +48,10 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", + help=( + "The maximum number of audio seconds in a batch." + "Determines batch size dynamically." 
+ ), ) parser.add_argument( @@ -144,9 +146,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0..9f8503814 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,9 +112,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -128,9 +126,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5..4a4093ae4 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,9 +83,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -101,9 +99,7 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8b..f149b7871 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,21 +46,19 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", + help=( + "The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words." + ), ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." 
- ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa7..3518db524 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index dff98a954..fbcc9e24a 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) removed += 1 return False @@ -125,8 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." ) return ans @@ -155,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743d..3459c2f5a 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a7442..e121aefa9 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,9 +150,7 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. 
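    # A minimal sketch of the round trip performed here, assuming a trained BPE
    # model at this recipe's default path; the ids and pieces shown in the
    # comments are hypothetical values, not taken from a real model.
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")

    ids = sp.encode(["HELLO"], out_type=int)      # e.g. [[22, 53, 8]]
    pieces = [sp.id_to_piece(i) for i in ids]     # e.g. [["▁HE", "LL", "O"]]
    # Each (word, pieces) pair becomes one lexicon entry, e.g. "HELLO ▁HE LL O".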
- words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f1..70343fef7 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039..8aa5e461d 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a..807aaf891 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,8 +64,7 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" ) @@ -85,9 +84,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100755 new mode 100644 index 27414d717..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -1,818 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./lstm_transducer_stateless/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - 
average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="lstm_transducer_stateless/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. 
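    As a brief, illustrative sketch of how this scale is consumed further down
    in main() (the path is this recipe's default lang dir and 0.01 is just this
    argument's default value):

        decoding_graph = k2.Fsa.from_dict(
            torch.load("data/lang_bpe_500/LG.pt", map_location="cpu")
        )
        decoding_graph.scores *= 0.01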
- """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. 
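    A hypothetical return value for a batch of two utterances decoded with
    greedy search (the words themselves are made up):

        {"greedy_search": [["HELLO", "WORLD"], ["FOO", "BAR"]]}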
- """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - # tail padding here to alleviate the tail deletion problem - num_tail_padded_frames = 35 - feature = torch.nn.functional.pad( - feature, - (0, 0, 0, num_tail_padded_frames), - mode="constant", - value=LOG_EPS, - ) - feature_lens += num_tail_padded_frames - - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return 
{"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
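        # As a concrete illustration (the file-name patterns follow the ones used
        # in this function; the WER value is made up): for test_set_name
        # "test-clean", key "greedy_search" and suffix "epoch-35-avg-15", this
        # writes
        #   errs-test-clean-greedy_search-epoch-35-avg-15.txt
        # with per-word error statistics and aligned ref/hyp pairs, and
        #   wer-summary-test-clean-greedy_search-epoch-35-avg-15.txt
        # containing lines like
        #   settings        WER
        #   greedy_search   3.05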
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - 
else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. 
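    # With return_cuts enabled below, each batch carries its lhotse Cut objects,
    # so decode_dataset() can recover stable utterance ids; a one-line sketch of
    # the lookup it performs (field names as used earlier in this file):
    #
    #   cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
    #
    # Those ids are then written out together with the reference and hypothesis
    # words in the recogs-*.txt files.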
- args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100755 new mode 100644 index 13dac6009..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.trace() - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 35 \ - --avg 10 \ - --jit-trace 1 - -It will generate 3 files: `encoder_jit_trace.pt`, -`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. - -(2) Export `model.state_dict()` - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 35 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `lstm_transducer_stateless/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./lstm_transducer_stateless/decode.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. 
You can get it at - -https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 - # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit-trace", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace. - It will generate 3 files: - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - - Check ./jit_pretrained.py for how to use them. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. 
- """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - states = encoder_model.get_init_states() - - traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" 
--iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit_trace is True: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - else: - logging.info("Not using torchscript") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100755 new mode 100644 index 594c33e4f..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -1,322 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, either exported by `torch.jit.trace()` -or by `torch.jit.script()`, and uses them to decode waves. 
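As a minimal sketch of that loading flow (it mirrors main() further down in this script; the exp dir path follows the export.py usage example and the random features are placeholders, not real fbank input):

```python
# Minimal sketch of loading and running the traced models (paths follow the
# export.py usage example; the random features are placeholders).
import torch

exp_dir = "lstm_transducer_stateless/exp"
encoder = torch.jit.load(f"{exp_dir}/encoder_jit_trace.pt")
decoder = torch.jit.load(f"{exp_dir}/decoder_jit_trace.pt")
joiner = torch.jit.load(f"{exp_dir}/joiner_jit_trace.pt")
for m in (encoder, decoder, joiner):
    m.eval()

features = torch.randn(1, 100, 80)                      # (N, T, num_mel_bins)
feature_lengths = torch.tensor([100], dtype=torch.int64)
states = encoder.get_init_states(batch_size=1, device=torch.device("cpu"))

encoder_out, encoder_out_lens, _ = encoder(
    x=features, x_lens=feature_lengths, states=states
)
# encoder_out / encoder_out_lens can then be fed to greedy_search() together
# with `decoder` and `joiner`, as done in main() below.
```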
-You can use the following command to get the exported models: - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --jit-trace 1 - -Usage of this script: - -./lstm_transducer_stateless/jit_pretrained.py \ - --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. 
- """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - states = encoder.get_init_states(batch_size=features.size(0), device=device) - - encoder_out, encoder_out_lens, _ = encoder( - x=features, - x_lens=feature_lengths, - states=states, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - 
encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c478..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -1,871 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -from typing import List, Optional, Tuple - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv2d, - ScaledLinear, - ScaledLSTM, -) -from torch import nn - -LOG_EPSILON = math.log(1e-10) - - -def unstack_states( - states: Tuple[torch.Tensor, torch.Tensor] -) -> List[Tuple[torch.Tensor, torch.Tensor]]: - """ - Unstack the lstm states corresponding to a batch of utterances into a list - of states, where the i-th entry is the state from the i-th utterance. - - Args: - states: - A tuple of 2 elements. - ``states[0]`` is the lstm hidden states, of a batch of utterance. - ``states[1]`` is the lstm cell states, of a batch of utterances. - - Returns: - A list of states. - ``states[i]`` is a tuple of 2 elememts of i-th utterance. - ``states[i][0]`` is the lstm hidden states of i-th utterance. - ``states[i][1]`` is the lstm cell states of i-th utterance. - """ - hidden_states, cell_states = states - - list_hidden_states = hidden_states.unbind(dim=1) - list_cell_states = cell_states.unbind(dim=1) - - ans = [ - (h.unsqueeze(1), c.unsqueeze(1)) - for (h, c) in zip(list_hidden_states, list_cell_states) - ] - return ans - - -def stack_states( - states_list: List[Tuple[torch.Tensor, torch.Tensor]] -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Stack list of lstm states corresponding to separate utterances into a single - lstm state so that it can be used as an input for lstm when those utterances - are formed into a batch. - - Args: - state_list: - Each element in state_list corresponds to the lstm state for a single - utterance. - ``states[i]`` is a tuple of 2 elememts of i-th utterance. - ``states[i][0]`` is the lstm hidden states of i-th utterance. - ``states[i][1]`` is the lstm cell states of i-th utterance. - - - Returns: - A new state corresponding to a batch of utterances. - It is a tuple of 2 elements. - ``states[0]`` is the lstm hidden states, of a batch of utterance. - ``states[1]`` is the lstm cell states, of a batch of utterances. 
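A quick round-trip sketch for the two helpers defined here (shapes follow RNN.get_init_states further below; the numbers are arbitrary):

```python
# Round-trip sketch for unstack_states()/stack_states() in this file; shapes
# follow RNN.get_init_states below (hidden: (num_layers, N, d_model),
# cell: (num_layers, N, rnn_hidden_size)).  Numbers are arbitrary.
import torch

num_layers, N, d_model, rnn_hidden_size = 12, 3, 512, 1024
batch_states = (
    torch.zeros(num_layers, N, d_model),          # hidden states
    torch.zeros(num_layers, N, rnn_hidden_size),  # cell states
)

per_utt = unstack_states(batch_states)
assert len(per_utt) == N
assert per_utt[0][0].shape == (num_layers, 1, d_model)
assert per_utt[0][1].shape == (num_layers, 1, rnn_hidden_size)

restacked = stack_states(per_utt)
assert torch.equal(restacked[0], batch_states[0])
assert torch.equal(restacked[1], batch_states[1])
```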
- """ - hidden_states = torch.cat([s[0] for s in states_list], dim=1) - cell_states = torch.cat([s[1] for s in states_list], dim=1) - ans = (hidden_states, cell_states) - return ans - - -class RNN(EncoderInterface): - """ - Args: - num_features (int): - Number of input features. - subsampling_factor (int): - Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa - d_model (int): - Output dimension (default=512). - dim_feedforward (int): - Feedforward dimension (default=2048). - rnn_hidden_size (int): - Hidden dimension for lstm layers (default=1024). - num_encoder_layers (int): - Number of encoder layers (default=12). - dropout (float): - Dropout rate (default=0.1). - layer_dropout (float): - Dropout value for model-level warmup (default=0.075). - aux_layer_period (int): - Period of auxiliary layers used for random combiner during training. - If set to 0, will not use the random combiner (Default). - You can set a positive integer to use the random combiner, e.g., 3. - is_pnnx: - True to make this class exportable via PNNX. - """ - - def __init__( - self, - num_features: int, - subsampling_factor: int = 4, - d_model: int = 512, - dim_feedforward: int = 2048, - rnn_hidden_size: int = 1024, - num_encoder_layers: int = 12, - dropout: float = 0.1, - layer_dropout: float = 0.075, - aux_layer_period: int = 0, - is_pnnx: bool = False, - ) -> None: - super(RNN, self).__init__() - - self.num_features = num_features - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling( - num_features, - d_model, - is_pnnx=is_pnnx, - ) - - self.is_pnnx = is_pnnx - - self.num_encoder_layers = num_encoder_layers - self.d_model = d_model - self.rnn_hidden_size = rnn_hidden_size - - encoder_layer = RNNEncoderLayer( - d_model=d_model, - dim_feedforward=dim_feedforward, - rnn_hidden_size=rnn_hidden_size, - dropout=dropout, - layer_dropout=layer_dropout, - ) - self.encoder = RNNEncoder( - encoder_layer, - num_encoder_layers, - aux_layers=list( - range( - num_encoder_layers // 3, - num_encoder_layers - 1, - aux_layer_period, - ) - ) - if aux_layer_period > 0 - else None, - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C), where N is the batch size, - T is the sequence length, C is the feature dimension. - x_lens: - A tensor of shape (N,), containing the number of frames in `x` - before padding. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (num_layers, N, d_model); - states[1] is the cell states of all layers, - with shape of (num_layers, N, rnn_hidden_size). - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - A tuple of 3 tensors: - - embeddings: its shape is (N, T', d_model), where T' is the output - sequence lengths. 
- - lengths: a tensor of shape (batch_size,) containing the number of - frames in `embeddings` before padding. - - updated states, whose shape is the same as the input states. - """ - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - if not self.is_pnnx: - lengths = (((x_lens - 3) >> 1) - 1) >> 1 - else: - lengths1 = torch.floor((x_lens - 3) / 2) - lengths = torch.floor((lengths1 - 1) / 2) - lengths = lengths.to(x_lens) - - if not torch.jit.is_tracing(): - assert x.size(0) == lengths.max().item() - - if states is None: - x = self.encoder(x, warmup=warmup)[0] - # torch.jit.trace requires returned types to be the same as annotated # noqa - new_states = (torch.empty(0), torch.empty(0)) - else: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == ( - self.num_encoder_layers, - x.size(1), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_encoder_layers, - x.size(1), - self.rnn_hidden_size, - ) - x, new_states = self.encoder(x, states) - - x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, batch_size: int = 1, device: torch.device = torch.device("cpu") - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Get model initial states.""" - # for rnn hidden states - hidden_states = torch.zeros( - (self.num_encoder_layers, batch_size, self.d_model), device=device - ) - cell_states = torch.zeros( - (self.num_encoder_layers, batch_size, self.rnn_hidden_size), - device=device, - ) - return (hidden_states, cell_states) - - -class RNNEncoderLayer(nn.Module): - """ - RNNEncoderLayer is made up of lstm and feedforward networks. - - Args: - d_model: - The number of expected features in the input (required). - dim_feedforward: - The dimension of feedforward network model (default=2048). - rnn_hidden_size: - The hidden dimension of rnn layer. - dropout: - The dropout value (default=0.1). - layer_dropout: - The dropout value for model-level warmup (default=0.075). - """ - - def __init__( - self, - d_model: int, - dim_feedforward: int, - rnn_hidden_size: int, - dropout: float = 0.1, - layer_dropout: float = 0.075, - ) -> None: - super(RNNEncoderLayer, self).__init__() - self.layer_dropout = layer_dropout - self.d_model = d_model - self.rnn_hidden_size = rnn_hidden_size - - assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) - self.lstm = ScaledLSTM( - input_size=d_model, - hidden_size=rnn_hidden_size, - proj_size=d_model if rnn_hidden_size > d_model else 0, - num_layers=1, - dropout=0.0, - ) - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. 
- - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (1, N, d_model); - states[1] is the cell states of all layers, - with shape of (1, N, rnn_hidden_size). - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - """ - src_orig = src - - warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else 0.1 - ) - else: - alpha = 1.0 - - # lstm module - if states is None: - src_lstm = self.lstm(src)[0] - # torch.jit.trace requires returned types be the same as annotated - new_states = (torch.empty(0), torch.empty(0)) - else: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == (1, src.size(1), self.d_model) - # for cell state - assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) - src_lstm, new_states = self.lstm(src, states) - src = self.dropout(src_lstm) + src - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src, new_states - - -class RNNEncoder(nn.Module): - """ - RNNEncoder is a stack of N encoder layers. - - Args: - encoder_layer: - An instance of the RNNEncoderLayer() class (required). - num_layers: - The number of sub-encoder-layers in the encoder (required). - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - aux_layers: Optional[List[int]] = None, - ) -> None: - super(RNNEncoder, self).__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - self.d_model = encoder_layer.d_model - self.rnn_hidden_size = encoder_layer.rnn_hidden_size - - self.aux_layers: List[int] = [] - self.combiner: Optional[nn.Module] = None - if aux_layers is not None: - assert len(set(aux_layers)) == len(aux_layers) - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer in turn. - - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (num_layers, N, d_model); - states[1] is the cell states of all layers, - with shape of (num_layers, N, rnn_hidden_size). - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. 
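The warmup mechanism referenced here (and implemented in RNNEncoderLayer.forward above) boils down to a per-layer blend between the layer output and its input. A stand-alone sketch, where `layer_out` stands for the LSTM + feed-forward + norm output computed in the layer body:

```python
# Stand-alone sketch of the layer-warmup blending used in
# RNNEncoderLayer.forward above; `layer_out` stands for the LSTM +
# feed-forward + norm output and `src_orig` for the layer input.
import torch


def blend_with_warmup(
    layer_out: torch.Tensor,
    src_orig: torch.Tensor,
    warmup: float,
    layer_dropout: float = 0.075,
    training: bool = True,
) -> torch.Tensor:
    warmup_scale = min(0.1 + warmup, 1.0)
    if training:
        # With probability `layer_dropout` the layer is almost bypassed
        # (alpha = 0.1); otherwise it is used with weight `warmup_scale`.
        alpha = (
            warmup_scale
            if torch.rand(()).item() <= (1.0 - layer_dropout)
            else 0.1
        )
    else:
        alpha = 1.0
    return alpha * layer_out + (1 - alpha) * src_orig
```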
- """ - if states is not None: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == ( - self.num_layers, - src.size(1), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_layers, - src.size(1), - self.rnn_hidden_size, - ) - - output = src - - outputs = [] - - new_hidden_states = [] - new_cell_states = [] - - for i, mod in enumerate(self.layers): - if states is None: - output = mod(output, warmup=warmup)[0] - else: - layer_state = ( - states[0][i : i + 1, :, :], # h: (1, N, d_model) - states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) - ) - output, (h, c) = mod(output, layer_state) - new_hidden_states.append(h) - new_cell_states.append(c) - - if self.combiner is not None and i in self.aux_layers: - outputs.append(output) - - if self.combiner is not None: - output = self.combiner(outputs) - - if states is None: - new_states = (torch.empty(0), torch.empty(0)) - else: - new_states = ( - torch.cat(new_hidden_states, dim=0), - torch.cat(new_cell_states, dim=0), - ) - - return output, new_states - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-3)//2-1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - is_pnnx: bool = False, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >= 9, in_channels >= 9. - out_channels - Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - is_pnnx: - True if we are converting the model to PNNX format. - False otherwise. - """ - assert in_channels >= 9 - super().__init__() - - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=0, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear( - layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels - ) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55 - ) - - # ncnn supports only batch size == 1 - self.is_pnnx = is_pnnx - self.conv_out_dim = self.out.weight.shape[1] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). 
- - Returns: - Return a tensor of shape (N, ((T-3)//2-1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - - if torch.jit.is_tracing() and self.is_pnnx: - x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) - x = self.out(x) - else: - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - - # Now x is of shape (N, ((T-3)//2-1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x - - -class RandomCombine(nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - - def __init__( - self, - num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0, - ) -> None: - """ - Args: - num_inputs: - The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: - The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: - The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: - A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, or combinations of layers, to use, - is conceptually as follows:: - - With probability `pure_prob`:: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else:: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super().__init__() - assert 0 <= pure_prob <= 1, pure_prob - assert 0 < final_weight < 1, final_weight - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev = stddev - - self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) - .log() - .item() - ) - - def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: - """Forward function. - Args: - inputs: - A list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - A Tensor of shape (*, num_channels). In test mode - this is just the final input. 
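A small usage sketch for RandomCombine (shapes are arbitrary): in eval mode it simply returns the last input, while in training it mixes the inputs with random per-frame weights drawn as described in the helpers below.

```python
# Usage sketch for RandomCombine: eval mode returns the final input, training
# mode mixes all inputs with random per-frame weights.  Shapes are arbitrary.
import torch

combiner = RandomCombine(num_inputs=3, final_weight=0.5, pure_prob=0.333, stddev=2.0)
layer_outputs = [torch.randn(5, 20, 512) for _ in range(3)]

combiner.eval()
assert torch.equal(combiner(layer_outputs), layer_outputs[-1])

combiner.train()
mixed = combiner(layer_outputs)      # random per-frame combination
assert mixed.shape == layer_outputs[-1].shape
```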
- """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not self.training or torch.jit.is_scripting(): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape( - (num_frames, num_channels, num_inputs) - ) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights( - inputs[0].dtype, inputs[0].device, num_frames - ) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - - ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) - - # The following if causes errors for torch script in torch 1.6.0 - # if __name__ == "__main__": - # # for testing only... - # print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - def _get_random_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> torch.Tensor: - """Return a tensor of random weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired - Returns: - A tensor of shape (num_frames, self.num_inputs), such that - `ans.sum(dim=1)` is all ones. - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where( - torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m - ) - - def _get_random_pure_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A one-hot tensor of shape `(num_frames, self.num_inputs)`, with - exactly one weight equal to 1.0 on each frame. - """ - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) - - indexes = torch.where( - torch.rand(num_frames, device=device) < final_prob, final, nonfinal - ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) - return ans - - def _get_random_mixed_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A tensor of shape (num_frames, self.num_inputs), which elements - in [0..1] that sum to one over the second axis, i.e. - `ans.sum(dim=1)` is all ones. 
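As a sanity check on the formula implemented just below: with `stddev == 0` the biased softmax assigns exactly `final_weight` to the last layer and splits the remainder evenly over the others.

```python
# Sanity-check sketch for the mixed-weight formula below: with stddev == 0 the
# softmax of the biased log-weights gives exactly `final_weight` to the last
# layer and splits the rest evenly.  Purely illustrative.
import math
import torch

num_inputs, final_weight = 3, 0.5
final_log_weight = math.log((final_weight / (1 - final_weight)) * (num_inputs - 1))

logprobs = torch.zeros(1, num_inputs)   # stddev == 0 -> no noise
logprobs[:, -1] += final_log_weight
weights = logprobs.softmax(dim=1)
# weights -> tensor([[0.25, 0.25, 0.50]])
assert torch.allclose(weights[0, -1], torch.tensor(final_weight))
```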
- """ - logprobs = ( - torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) - * self.stddev # noqa - ) - logprobs[:, -1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa - ) - num_inputs = 3 - num_channels = 50 - m = RandomCombine( - num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev, - ) - - x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] - - y = m(x) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -def _test_random_combine_main(): - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - feature_dim = 50 - c = RNN(num_features=feature_dim, d_model=128) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - feature_dim = 80 - m = RNN( - num_features=feature_dim, - d_model=512, - rnn_hidden_size=1024, - dim_feedforward=2048, - num_encoder_layers=12, - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = m( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, - ) - num_param = sum([p.numel() for p in m.parameters()]) - print(f"Number of model parameters: {num_param}") - - _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4a..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -1,210 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). 
- It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - reduction: str = "sum", - delay_penalty: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - delay_penalty: - A constant value used to penalize symbol delay, to encourage - streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details. - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert reduction in ("sum", "none"), reduction - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. 
- sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - delay_penalty=delay_penalty, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100755 new mode 100644 index 2a6e2adc6..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
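One detail of Transducer.forward above worth spelling out is the `boundary` tensor handed to the k2 RNN-T losses: one row per utterance, the first two entries left at zero and the last two set to the label length and the (subsampled) encoder output length. A tiny stand-alone illustration with arbitrary values:

```python
# Stand-alone illustration of the `boundary` tensor built in
# Transducer.forward above.  Values are arbitrary.
import torch

x_lens = torch.tensor([100, 90])   # encoder output lengths after subsampling
y_lens = torch.tensor([12, 7])     # number of BPE tokens per utterance

boundary = torch.zeros((x_lens.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
print(boundary)
# tensor([[  0,   0,  12, 100],
#         [  0,   0,   7,  90]])
```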
-""" -Usage: - -(1) greedy search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. - -Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by -./lstm_transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - 
hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c82..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -1,148 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class Stream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - LOG_EPS: float = math.log(1e-10), - ) -> None: - """ - Args: - params: - It's the return value of :func:`get_params`. - cut_id: - The cut id of the current stream. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - device: - The device to run this stream. - LOG_EPS: - A float value used for padding. - """ - self.LOG_EPS = LOG_EPS - self.cut_id = cut_id - - # Containing attention caches and convolution caches - self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # It uses different attributes for different decoding methods. 
- self.context_size = params.context_size - self.decoding_method = params.decoding_method - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) - ) - self.hyp: Optional[List[int]] = None - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - self.ground_truth: str = "" - - self.feature: Optional[torch.Tensor] = None - # Make sure all feature frames can be used. - # We aim to obtain 1 frame after subsampling. - self.chunk_length = params.subsampling_factor - self.pad_length = 5 - self.num_frames = 0 - self.num_processed_frames = 0 - - # After all feature frames are processed, we set this flag to True - self._done = False - - def set_feature(self, feature: torch.Tensor) -> None: - assert feature.dim() == 2, feature.dim() - # tail padding here to alleviate the tail deletion problem - num_tail_padded_frames = 35 - self.num_frames = feature.size(0) + num_tail_padded_frames - self.feature = torch.nn.functional.pad( - feature, - (0, 0, 0, self.pad_length + num_tail_padded_frames), - mode="constant", - value=self.LOG_EPS, - ) - - def get_feature_chunk(self) -> torch.Tensor: - """Get a chunk of feature frames. - - Returns: - A tensor of shape (ret_length, feature_dim). - """ - update_length = min( - self.num_frames - self.num_processed_frames, self.chunk_length - ) - ret_length = update_length + self.pad_length - - ret_feature = self.feature[ - self.num_processed_frames : self.num_processed_frames + ret_length - ] - # Cut off used frames. - # self.feature = self.feature[update_length:] - - self.num_processed_frames += update_length - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_feature - - @property - def id(self) -> str: - return self.cut_id - - @property - def done(self) -> bool: - """Return True if all feature frames are processed.""" - return self._done - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.decoding_method == "greedy_search": - return self.hyp[self.context_size :] - elif self.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.context_size :] - else: - assert self.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100755 new mode 100644 index d6376bdc0..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -1,968 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
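
As an aside on the Stream class removed just above: the intended consumption pattern is to attach the full feature matrix once with set_feature() and then drain it chunk by chunk until the done flag flips. A rough usage sketch, assuming the deleted stream.py is importable from the recipe directory; the params values and the 2000-frame feature are made up for illustration.

import torch
from icefall.utils import AttributeDict
from stream import Stream   # the lstm_transducer_stateless/stream.py shown above

params = AttributeDict(
    {
        "context_size": 2,
        "blank_id": 0,
        "decoding_method": "greedy_search",
        "subsampling_factor": 4,
    }
)
stream = Stream(params=params, cut_id="utt-0")
stream.set_feature(torch.randn(2000, 80))    # (num_frames, feature_dim), toy values

chunks = []
while not stream.done:
    # Each call returns chunk_length + pad_length frames and advances
    # the internal pointer by chunk_length frames.
    chunks.append(stream.get_feature_chunk())
print(len(chunks), chunks[0].shape)
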
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method greedy_search \ - --use-averaged-model True - -(2) modified beam search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method modified_beam_search \ - --use-averaged-model True \ - --beam-size 4 - -(3) fast beam search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method fast_beam_search \ - --use-averaged-model True \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" -import argparse -import logging -import warnings -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from lstm import LOG_EPSILON, stack_states, unstack_states -from stream import Stream -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import one_best_decoding -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_emformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--sampling-rate", - type=float, - default=16000, - help="Sample rate of the audio", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded in parallel", - ) - - add_model_arguments(parser) - - return parser - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[Stream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. 
- """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - T = encoder_out.size(1) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (batch_size, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[Stream], - beam: int = 4, -): - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - beam: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. 
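
The torch.index_select call that follows is the heart of the batched beam search: hypotheses from all streams are flattened into one ragged list, and row_ids(1) maps each flattened hypothesis back to the stream it came from, so each stream's encoder frame is duplicated once per live hypothesis. A tiny illustration with made-up beam sizes (2 and 3 hypotheses for two streams), again assuming k2 is available:

import torch
import k2

# Two streams currently holding 2 and 3 hypotheses respectively.
row_splits = torch.tensor([0, 2, 5], dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(row_splits=row_splits, cached_tot_size=5)
print(shape.row_ids(1))              # tensor([0, 0, 1, 1, 1], dtype=torch.int32)

# One encoder frame per stream at time t: (batch=2, 1, 1, encoder_dim=4).
current_encoder_out = torch.arange(8.0).reshape(2, 1, 1, 4)

# Repeat each stream's frame once per live hypothesis -> (num_hyps, 1, 1, 4).
expanded = torch.index_select(
    current_encoder_out, dim=0, index=shape.row_ids(1).to(torch.int64)
)
print(expanded.shape)                # torch.Size([5, 1, 1, 4])
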
- current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - streams: List[Stream], - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - streams: - A list of stream objects. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. 
- """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - assert B == len(streams) - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyps[i] - - -def decode_one_chunk( - model: nn.Module, - streams: List[Stream], - params: AttributeDict, - decoding_graph: Optional[k2.Fsa] = None, -) -> List[int]: - """ - Args: - model: - The Transducer model. - streams: - A list of Stream objects. - params: - It is returned by :func:`get_params`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search. - - Returns: - A list of indexes indicating the finished streams. 
- """ - device = next(model.parameters()).device - - feature_list = [] - feature_len_list = [] - state_list = [] - num_processed_frames_list = [] - - for stream in streams: - # We should first get `stream.num_processed_frames` - # before calling `stream.get_feature_chunk()` - # since `stream.num_processed_frames` would be updated - num_processed_frames_list.append(stream.num_processed_frames) - feature = stream.get_feature_chunk() - feature_len = feature.size(0) - feature_list.append(feature) - feature_len_list.append(feature_len) - state_list.append(stream.states) - - features = pad_sequence( - feature_list, batch_first=True, padding_value=LOG_EPSILON - ).to(device) - feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) - - # Make sure it has at least 1 frame after subsampling - tail_length = params.subsampling_factor + 5 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPSILON, - ) - - # Stack states of all streams - states = stack_states(state_list) - - encoder_out, encoder_out_lens, states = model.encoder( - x=features, - x_lens=feature_lens, - states=states, - ) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, - streams=streams, - encoder_out=encoder_out, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=streams, - encoder_out=encoder_out, - beam=params.beam_size, - ) - elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. - # The rnnt_decoding_stream for fast_beam_search. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens - ) - fast_beam_search_one_best( - model=model, - streams=streams, - encoder_out=encoder_out, - processed_lens=processed_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - # Update cached states of each stream - state_list = unstack_states(states) - for i, s in enumerate(state_list): - streams[i].states = s - - finished_streams = [i for i, stream in enumerate(streams) if stream.done] - return finished_streams - - -def create_streaming_feature_extractor() -> Fbank: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return Fbank(opts) - - -def decode_dataset( - cuts: CutSet, - model: nn.Module, - params: AttributeDict, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -): - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The Transducer model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search. 
- - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = next(model.parameters()).device - - log_interval = 300 - - fbank = create_streaming_feature_extractor() - - decode_results = [] - streams = [] - for num, cut in enumerate(cuts): - # Each utterance has a Stream. - stream = Stream( - params=params, - cut_id=cut.id, - decoding_graph=decoding_graph, - device=device, - LOG_EPS=LOG_EPSILON, - ) - - stream.states = model.encoder.get_init_states(device=device) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - feature = fbank(samples) - stream.set_feature(feature) - stream.ground_truth = cut.supervisions[0].text - - streams.append(stream) - - while len(streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - model=model, - streams=streams, - params=params, - decoding_graph=decoding_graph, - ) - - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - streams[i].id, - streams[i].ground_truth.split(), - sp.decode(streams[i].decoding_result()).split(), - ) - ) - del streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - while len(streams) > 0: - finished_streams = decode_one_chunk( - model=model, - streams=streams, - params=params, - decoding_graph=decoding_graph, - ) - - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - streams[i].id, - streams[i].ground_truth.split(), - sp.decode(streams[i].decoding_result()).split(), - ) - ) - del streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - else: - key = f"beam_size_{params.beam_size}" - - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=sorted(results)) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
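
write_error_stats (called just below) produces the detailed per-word statistics; conceptually, the WER reported for each key is a word-level edit distance summed over all (ref, hyp) pairs and divided by the total number of reference words. A rough stand-alone illustration of that idea, not icefall's implementation:

def word_edit_distance(ref, hyp):
    # Classic dynamic-programming edit distance over word lists.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[-1][-1]

results = [
    ("cut-1", "the cat sat".split(), "the cat sat".split()),
    ("cut-2", "hello world".split(), "hello word".split()),
]
errors = sum(word_edit_distance(ref, hyp) for _, ref, hyp in results)
total = sum(len(ref) for _, ref, _ in results)
print(f"WER = {100.0 * errors / total:.2f}%")   # 20.00% for this toy pair
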
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-streaming-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - params.device = device - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if 
params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - model=model, - params=params, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - torch.manual_seed(20220810) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100755 new mode 100644 index d30fc260a..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -1,1157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
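
The --use-averaged-model branches removed above rely on a running average of parameters that training maintains; the update rule is the one quoted in the --average-period help further down (model_avg = model * (p / n) + model_avg * ((n - p) / n)). A minimal sketch of that rule on a single tensor, with made-up numbers (the real logic lives in icefall's update_averaged_model and average_checkpoints_with_averaged_model):

import torch

def update_average(model_param, avg_param, average_period, batch_idx_train):
    # Fold the last `average_period` batches of the current model into the
    # running average kept since the start of training.
    p, n = float(average_period), float(batch_idx_train)
    return model_param * (p / n) + avg_param * ((n - p) / n)

avg = torch.zeros(3)
cur = torch.ones(3)
for step in (100, 200, 300):             # every average_period = 100 batches
    avg = update_average(cur, avg, average_period=100, batch_idx_train=step)
print(avg)                               # tensor([1., 1., 1.]) since the model never changed
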
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./lstm_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir lstm_transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./lstm_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir lstm_transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from lstm import RNN -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - display_and_save_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of RNN encoder layers..", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=512, - help="Encoder output dimesion.", - ) - - parser.add_argument( - "--rnn-hidden-size", - type=int, - default=1024, - help="Hidden dim for LSTM layers.", - ) - - parser.add_argument( - "--aux-layer-period", - type=int, - default=0, - help="""Peroid of auxiliary layers used for randomly combined during training. - If set to 0, will not use the random combiner (Default). - You can set a positive integer to use the random combiner, e.g., 3. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=35, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. 
- If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="lstm_transducer_stateless/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="""The initial learning rate. This value should not need to be - changed.""", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=10, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. 
`model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value used to penalize symbol delay, - to encourage streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details.""", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = RNN( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - rnn_hidden_size=params.rnn_hidden_size, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - aux_layer_period=params.aux_layer_period, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. 
- """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. 
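
A few lines up, pruned_loss_scale is gated on the warmup value (batch_idx_train / model_warm_step). Written out as a stand-alone function, the schedule used above is simply the following; the asserts are just a quick illustration of the three regimes.

def pruned_loss_scale(warmup: float) -> float:
    # Train on the simple loss only until warm-up finishes, then keep the
    # pruned loss small for one more warm-up period before letting it dominate.
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

assert pruned_loss_scale(0.5) == 0.0     # batch_idx_train < model_warm_step
assert pruned_loss_scale(1.5) == 0.1
assert pruned_loss_scale(2.5) == 1.0
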
- info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
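
The scaler calls that follow are the standard torch.cuda.amp pattern (the recipe additionally steps its Eden scheduler between backward and scaler.step). For reference, a stripped-down loop showing the same call order with a dummy model; the model, optimizer and sizes here are placeholders, not the recipe's classes, and fp16 is only enabled when a GPU is present so the sketch also runs on CPU.

import torch
from torch.cuda.amp import GradScaler, autocast

use_fp16 = torch.cuda.is_available()     # mirrors the params.use_fp16 gating
device = "cuda" if use_fp16 else "cpu"
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_fp16)

for _ in range(3):
    x = torch.randn(8, 10, device=device)
    with autocast(enabled=use_fp16):
        loss = model(x).pow(2).sum()     # summed, not averaged, as in compute_loss
    scaler.scale(loss).backward()        # backward on the (possibly) scaled loss
    scaler.step(optimizer)               # unscales grads; skips the step on inf/nan
    scaler.update()                      # adjusts the loss scale for the next batch
    optimizer.zero_grad()
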
- scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 30: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - if params.full_libri is False: - params.valid_interval = 800 - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - # # overwrite it - # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] - # print(scheduler.base_lrs) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - librispeech = LibriSpeechAsrDataModule(args) - - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./lstm.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 3) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. 
" - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." 
- ) - raise - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e..f7e1b5a54 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,20 +185,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -295,8 +299,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -474,9 +477,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +536,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -700,9 +698,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -735,8 +731,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -789,9 +784,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -826,13 +819,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -860,13 +852,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -895,7 +886,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -961,9 +952,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 
190673638..0ad00cda3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,20 +146,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -225,8 +229,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -342,9 +345,7 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) + c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) warmup = 1.0 torch.onnx.export( @@ -445,13 +446,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -550,13 +547,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -585,13 +581,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: 
raise ValueError( @@ -620,7 +615,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -694,9 +689,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76f..5a8efd718 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,10 +86,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -124,10 +126,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -315,9 +316,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d..3b471fa85 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,9 +156,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -200,10 +198,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -286,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760..7d931a286 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,9 +92,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." 
+ ), ) parser.add_argument( @@ -119,10 +121,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -169,8 +173,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,10 +204,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,15 +269,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +345,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e..baff15ea6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,9 +144,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -188,10 +186,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -229,9 +226,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -310,9 +305,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,9 +321,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -343,9 +334,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18..b31fefa0a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,10 +109,12 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -147,10 +149,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -199,9 +200,7 @@ class Model: sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +257,7 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +300,7 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +319,7 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +358,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +461,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f..08a895a75 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,42 +235,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -645,11 +645,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +688,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +701,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +714,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +945,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +991,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False @@ -1155,9 +1139,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379..a8d5605fb 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,20 +182,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -290,8 +294,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -386,9 +389,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +442,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -522,9 +520,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. 
Args: @@ -599,9 +595,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +604,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -650,8 +642,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +669,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +713,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,13 +745,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -787,13 +773,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -821,7 +806,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -848,9 +833,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad6..51238f768 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,20 +122,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -172,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -281,13 +284,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -310,13 +312,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -344,7 +345,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -380,9 +381,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a..180ba8c72 100755 --- 
a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,10 +85,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -123,10 +125,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -314,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f4..6e51b85e4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,9 +661,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +758,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef04..4f8049245 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,9 +89,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -116,10 +118,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). 
" - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -166,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +266,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +342,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed5..4e9063a40 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,8 +101,9 @@ def get_parser(): "--epoch", type=int, default=40, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -119,20 +120,24 @@ def get_parser(): "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
+ "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -199,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +363,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +380,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +539,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +581,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +593,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +768,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +810,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,13 +844,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < 
params.avg: raise ValueError( @@ -881,13 +872,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -915,7 +905,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be7..a1d19fb73 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -232,42 +230,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -606,11 +607,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +647,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +660,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +673,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +840,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +857,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +992,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca..fd2a5354a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,17 +74,18 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." 
+ ), ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -96,75 +97,91 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." 
+ ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +195,22 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -208,20 +229,16 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -236,9 +253,7 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +296,7 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. 
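            # The `input_strategy` chosen just below is what --on-the-fly-feats
            # controls: OnTheFlyFeatures(Fbank(...)) computes the 80-bin fbank
            # features inside the dataloader workers, while the branch without
            # it falls back to lhotse's default strategy of reading features
            # that were precomputed and stored in the manifests.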
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +353,7 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +400,17 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b..785a8f097 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a35204..3b6d0549d 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. 
- encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, + knowledge_N, + knowledge_D, + knowledge_K, + d_model, + knowledge_base, + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +313,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +367,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +382,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +657,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,33 +728,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. 
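        # From this point attn_mask, when provided, is guaranteed to be 3-D,
        # either (1, tgt_len, src_len) or (bsz * num_heads, tgt_len, src_len),
        # so the masking further below can rely on broadcasting over the
        # leading dimension.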
# convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -795,9 +783,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +827,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +852,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a..65da19f27 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -98,16 +94,19 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -186,8 +185,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +243,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +258,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -309,11 +302,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -385,9 +374,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -419,8 +406,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf1..0b9c886c7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd..2ca76a30c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
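decoder2.py implements the stateless transducer predictor: the last `context_size` labels are embedded and mixed by a 1-D convolution, and `need_pad` decides whether the output is left-padded to keep the label-axis length (training) or produced for a single decoding step. A rough usage sketch follows; the constructor arguments and the vocabulary size of 500 are assumptions for illustration, not taken from this patch:

import torch

from decoder2 import Decoder

# Assumed signature: Decoder(vocab_size, embedding_dim, blank_id, context_size)
decoder = Decoder(vocab_size=500, embedding_dim=512, blank_id=0, context_size=2)

y = torch.randint(low=1, high=500, size=(8, 25))  # (batch, U) training labels
out = decoder(y, need_pad=True)                   # (batch, U, embedding_dim)

prev = torch.tensor([[3, 17]])                    # the last context_size tokens
step = decoder(prev, need_pad=False)              # (1, 1, embedding_dim)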
+from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ class Decoder(nn.Module): return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,38 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 
2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = ( + "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," + " scale={scale}" + ) if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fb..1af05d9c8 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,17 +64,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -105,8 +108,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) return parser @@ -174,9 +176,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2a..68c663b66 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf2506..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf8220..76cd4e11e 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() 
> 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00..8cc930927 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import timeit -import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional -from scaling import ScaledLinear import random +import timeit +from typing import Optional, Tuple + +import torch +from scaling import ScaledLinear +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. - - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. 
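The function being reformatted here, weighted_matrix_lookup, is simply a batched weighted sum over K selected knowledge-base rows; the einsum version below (a sketch for illustration only, with made-up sizes, not part of the patch) computes the same thing:

import torch

def weighted_matrix_lookup_ref(weights, indexes, knowledge_base):
    # weights: (*, K); indexes: (*, K) with values in [0, M*N); kb: (M*N, D)
    rows = knowledge_base[indexes]                          # (*, K, D)
    return torch.einsum("...k,...kd->...d", weights, rows)  # (*, D)

kb = torch.randn(256, 32)
weights = torch.softmax(torch.randn(4, 7, 8), dim=-1)       # K = 8
indexes = torch.randint(0, 256, (4, 7, 8))
assert weighted_matrix_lookup_ref(weights, indexes, kb).shape == (4, 7, 32)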
@@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -88,15 +87,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. 
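            # Recomputing `lookup` here is a deliberate memory/compute trade:
            # the forward pass discards the gathered (*, K, D) rows, and the
            # backward pass gathers them again, so only weights, indexes and
            # the knowledge base itself need to be saved for backward.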
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. 
""" - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,21 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. 
+ y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device), + torch.randn(B, T, E, device=device), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +337,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +350,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) - -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c2583..527c735eb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def 
forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a..3f21133a0 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. 
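
# [Editorial sketch -- not part of the patch] Minimal usage of the ScaledLinear
# pattern described in the docstring above: the stored weight/bias are multiplied
# by exp(weight_scale)/exp(bias_scale) before the usual F.linear call, so the
# module behaves like a plain linear layer built from get_weight()/get_bias().
# The import path is an assumption (it presumes this file is importable as
# ``scaling_tmp``); everything else only uses code shown in this diff.
import torch
import torch.nn.functional as F

from scaling_tmp import ScaledLinear  # hypothetical import path

scaled = ScaledLinear(4, 8, initial_scale=1.0)
x = torch.randn(2, 4)
y = scaled(x)  # forward() uses the rescaled parameters
assert torch.allclose(y, F.linear(x, scaled.get_weight(), scaled.get_bias()))
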
""" - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + 
self.groups, + ) class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. 
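
# [Editorial sketch -- not part of the patch] Quick numerical check of the
# derivative identity quoted in the docstring above: with s = sigmoid(x - 1)
# and y = double_swish(x) = x * s, autograd's dy/dx equals y * (1 - s) + s,
# which is why the backward pass only needs to remember s and y, not x.
import torch

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s                      # double_swish(x)
y.sum().backward()             # autograd derivative, elementwise
assert torch.allclose(x.grad, (y * (1 - s) + s).detach())
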
""" + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def 
extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f6840166..a60d15c3b 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,42 +177,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -554,23 +555,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +727,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +827,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d30..1df1650f3 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,20 +123,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -204,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +275,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +290,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -338,11 +336,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -450,8 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -494,9 +485,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,13 +517,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,13 +545,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No 
checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -591,7 +578,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094..008f40fb1 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f5001..81afb523d 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -170,13 +173,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -199,13 +201,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -233,7 +234,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -273,9 +274,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdb..ed6848879 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19..6b30d3be8 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,42 +209,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -566,11 +569,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +598,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -908,8 +903,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d7..830b37cfb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
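
# [Editorial sketch -- not part of the patch] Two score-handling details of
# HypothesisList touched in this hunk: add() merges duplicate token sequences,
# combining their log-probs with logaddexp (use_max=False, i.e. probabilities
# are summed) or keeping only the better one (use_max=True); get_most_probable()
# below can length-normalize, comparing log_prob / len(ys) instead of raw log_prob.
import torch

a = torch.tensor(-1.2)          # score of the existing hypothesis
b = torch.tensor(-2.3)          # score of a duplicate being added
merged = torch.logaddexp(a, b)  # ~ -0.91: log(exp(a) + exp(b))
kept = torch.max(a, b)          # -1.20: max-merge alternative
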
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b6338948..03ad45f49 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,11 +128,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -171,9 +167,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -269,8 +267,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +380,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +445,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +576,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -619,8 +609,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +667,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +705,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +743,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 386248554..e522943c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. 
- self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a0..72593173c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,9 +92,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878..64708e524 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,17 +64,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -105,8 +108,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -192,9 +194,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3f..2cca7fa27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af..a42b63b9c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,9 +91,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -118,10 +120,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -168,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -221,10 +224,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -292,9 +294,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +381,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f..9e09200a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f..a50b4d4f0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -94,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -162,8 +160,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -269,9 +266,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +413,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -460,8 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -533,8 +521,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a29..dd0331a60 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,42 +203,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -562,9 +565,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -584,9 +585,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +776,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -897,8 +894,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False @@ -956,9 +952,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7..5e9428b60 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +775,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +791,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -990,9 +986,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +998,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1676,9 +1668,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +1794,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +1804,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + 
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +1827,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1995,9 +1979,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2032,10 +2014,7 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2067,9 +2046,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33b..34ff0d7e2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. 
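# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the per-token score update that
# modified_beam_search_rnnlm_shallow_fusion (in the beam_search.py hunks above)
# performs.  A non-blank token adds the transducer log-prob plus the RNNLM
# log-prob scaled by lm_scale; a blank keeps the LM state and LM score
# untouched.  The class and function names below are made up for the example;
# the real code operates on batched k2 ragged tensors.
# ---------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyHyp:
    ys: List[int] = field(default_factory=list)  # emitted (non-blank) token ids
    log_prob: float = 0.0                        # fused AM + scaled LM score


def extend_hyp(
    hyp: ToyHyp,
    token: int,
    am_log_prob: float,
    lm_log_prob: float,
    lm_scale: float,
    blank_id: int = 0,
) -> ToyHyp:
    """Return a copy of `hyp` extended by `token` (single-hypothesis toy)."""
    new_hyp = ToyHyp(ys=list(hyp.ys), log_prob=hyp.log_prob + am_log_prob)
    if token != blank_id:
        new_hyp.ys.append(token)
        # Shallow fusion: only non-blank tokens are scored by the RNNLM.
        new_hyp.log_prob += lm_scale * lm_log_prob
    return new_hyp


if __name__ == "__main__":
    h = ToyHyp()
    h = extend_hyp(h, token=5, am_log_prob=-0.7, lm_log_prob=-1.2, lm_scale=0.3)
    h = extend_hyp(h, token=0, am_log_prob=-0.1, lm_log_prob=0.0, lm_scale=0.3)
    print(h.ys, round(h.log_prob, 2))  # [5] -1.16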
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +776,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +800,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1114,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,33 +1185,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. 
# convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1264,23 +1243,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1293,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1322,9 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1498,16 +1461,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e..32cd53be3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -177,9 +173,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -275,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -397,9 +394,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +460,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -514,11 +506,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -608,9 +596,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -643,8 +629,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -700,9 +685,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +723,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" 
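# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): what averaging the selected
# checkpoints amounts to conceptually -- an element-wise mean over the model
# state_dicts loaded from `filenames`.  This toy helper is a stand-in for
# illustration only, not icefall's average_checkpoints().
# ---------------------------------------------------------------------------
from typing import Dict, List

import torch


def toy_average_state_dicts(
    state_dicts: List[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Average state_dicts that share identical keys and tensor shapes."""
    avg = {k: v.clone().to(torch.float32) for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k].to(torch.float32)
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg


if __name__ == "__main__":
    sds = [{"w": torch.full((2, 2), float(i))} for i in (1, 2, 3)]
    print(toy_average_state_dicts(sds)["w"])  # every element is 2.0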
+ f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +761,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce..b59928103 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,15 +107,11 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589..90367bd03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -173,8 +170,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +218,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d08033..1954f4724 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,9 +60,7 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45..2d7f557ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ 
-147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab..58de6875f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,9 +91,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -118,10 +120,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -168,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,10 +225,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -293,9 +295,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +382,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef..f671e97b1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ class 
ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1001,17 +988,18 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median *" + " threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9bcd2f9f9..e6e0fb1c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( 
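# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the parametrization shared by
# the Scaled* modules edited above -- the weight used in the forward pass is
# the stored weight times exp(weight_scale), so the overall magnitude is
# learned through a single log-scale scalar.  This is a minimal
# re-implementation for illustration, not icefall's ScaledLinear.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class MiniScaledLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        std = 0.1
        a = (3**0.5) * std
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features).uniform_(-a, a)
        )
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Chosen so that weight * exp(weight_scale) initially has rms ~ fan_in**-0.5.
        self.weight_scale = nn.Parameter(torch.tensor(in_features**-0.5 / std).log())
        self.bias_scale = nn.Parameter(torch.zeros(()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(
            x,
            self.weight * self.weight_scale.exp(),
            self.bias * self.bias_scale.exp(),
        )


if __name__ == "__main__":
    m = MiniScaledLinear(4, 3)
    print(m(torch.randn(2, 4)).shape)  # torch.Size([2, 3])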
index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a03946..0139863a1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -94,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -162,8 +160,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -271,9 +268,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +288,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +344,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
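# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the per-utterance bookkeeping a
# decode stream performs during streaming decoding -- consume the features in
# fixed-size chunks and carry an opaque encoder state between chunks.  The
# class and the run_chunk callback below are hypothetical simplifications of
# DecodeStream, shown only to make the control flow explicit.
# ---------------------------------------------------------------------------
from typing import Callable, List, Optional, Tuple

import torch


class ToyDecodeStream:
    def __init__(self, features: torch.Tensor, chunk_size: int) -> None:
        self.features = features  # (num_frames, feature_dim)
        self.chunk_size = chunk_size
        self.offset = 0
        self.state: Optional[torch.Tensor] = None  # carried across chunks
        self.hyp: List[int] = []

    @property
    def done(self) -> bool:
        return self.offset >= self.features.size(0)

    def next_chunk(self) -> torch.Tensor:
        chunk = self.features[self.offset : self.offset + self.chunk_size]
        self.offset += self.chunk_size
        return chunk


def run_stream(
    stream: ToyDecodeStream,
    run_chunk: Callable[
        [torch.Tensor, Optional[torch.Tensor]], Tuple[List[int], torch.Tensor]
    ],
) -> List[int]:
    """Drive one stream to completion, chunk by chunk."""
    while not stream.done:
        tokens, stream.state = run_chunk(stream.next_chunk(), stream.state)
        stream.hyp.extend(tokens)
    return stream.hyp


if __name__ == "__main__":
    stream = ToyDecodeStream(torch.randn(35, 80), chunk_size=16)

    def dummy_run_chunk(chunk, state):
        # Stand-in for "run the encoder on this chunk + greedy search".
        new_state = chunk.mean(dim=0) if state is None else state + chunk.mean(dim=0)
        return [], new_state

    print(run_stream(stream, dummy_run_chunk))  # []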
decode_stream = DecodeStream( @@ -425,9 +416,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -462,8 +451,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -536,8 +524,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf..623bdd51a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -234,42 +231,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." 
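# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): how the two transducer losses
# are combined in compute_loss() further down.  The simple loss is scaled by
# --simple-loss-scale, and the pruned loss is faded in with a warmup-dependent
# factor so it cannot overwhelm the simple loss early in training.  The helper
# name is made up; the scale schedule copies the recipe's values.
# ---------------------------------------------------------------------------
def combine_transducer_losses(
    simple_loss: float,
    pruned_loss: float,
    simple_loss_scale: float = 0.5,
    warmup: float = 1.0,
) -> float:
    pruned_loss_scale = (
        0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
    )
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss


if __name__ == "__main__":
    # Early in training only the (scaled) simple loss contributes.
    print(combine_transducer_losses(10.0, 8.0, warmup=0.5))  # 5.0
    # After warmup both terms contribute fully.
    print(combine_transducer_losses(10.0, 8.0, warmup=2.5))  # 13.0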
+ ), ) parser.add_argument( @@ -634,9 +634,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +647,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +660,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +828,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -963,8 +952,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee5..5e81aef07 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -44,59 +41,69 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -117,18 +124,22 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -142,9 +153,11 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. 
Used only in dev/test CutSet" + ), ) def train_dataloaders( @@ -167,9 +180,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +189,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +259,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78ba..66c8e30ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -120,9 +116,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -192,8 +190,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +277,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +307,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -359,21 +351,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -446,9 +428,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +461,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +511,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +544,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be1..d90497e26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,11 +120,7 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from 
icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -167,9 +163,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -478,9 +475,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +545,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -646,21 +638,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - f"temperature_{params.temperature}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps } elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - f"temperature_{params.temperature}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -690,12 +672,7 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps - } + return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} def decode_dataset( @@ -779,9 +756,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -814,8 +789,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -939,9 +913,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += 
f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -981,8 +953,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1003,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,9 +1031,7 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba05..dcf65e937 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,11 +128,7 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -164,9 +160,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -235,8 +233,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -509,13 +506,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -616,8 +609,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -715,9 +707,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b3..598434f54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,18 +52,14 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7db..108915389 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,10 +104,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -142,10 +144,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -330,9 +331,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e9..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef..163d737e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,9 +203,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -214,16 +212,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -232,11 +224,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -288,9 +276,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py 
b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ea5d4e674..11597aa49 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,10 +102,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -140,10 +142,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -191,11 +192,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -382,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a23..849d6cf4e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -231,10 +234,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -302,9 +304,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +391,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e6022b57..85d87f8f2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -251,12 +249,10 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) + raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") + raise AttributeError("`" + item + "` is not an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e00..41a712498 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -95,9 +91,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -163,8 +161,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -272,9 +269,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +289,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +345,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +417,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -461,8 +450,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -535,8 +523,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec..598fcf344 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ 
-236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a..6724343dd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -238,42 +234,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
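# Illustrative sketch, not part of the patch: the checks in the onnx_check.py
# and test_onnx.py hunks above all follow one pattern -- run the exported graph
# with onnxruntime and compare it to the original module with torch.allclose.
# A generic version of that check (hypothetical function name; assumes a
# single-input, single-output model):
import onnxruntime as ort
import torch


def check_onnx_export(onnx_path: str, torch_module: torch.nn.Module, x: torch.Tensor):
    session = ort.InferenceSession(onnx_path)
    (onnx_y,) = session.run(
        [session.get_outputs()[0].name],
        {session.get_inputs()[0].name: x.numpy()},
    )
    onnx_y = torch.from_numpy(onnx_y)
    torch_y = torch_module(x)
    assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()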
We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -672,9 +671,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +684,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +697,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +909,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +955,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1109,9 +1096,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a6..69cfcd298 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,20 +197,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. 
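# Illustrative sketch, not part of the patch: compute_loss() above phases the
# pruned RNN-T loss in as training warms up.  The same schedule in isolation
# (the expression is copied from the hunk; the function name is hypothetical):
def get_pruned_loss_scale(warmup: float) -> float:
    # Roughly: 0.0 before warmup reaches 1, 0.1 while warmup is between 1 and 2,
    # and 1.0 once the model is fully warmed up.
    return 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)


# loss = params.simple_loss_scale * simple_loss + get_pruned_loss_scale(warmup) * pruned_loss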
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -306,8 +310,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -427,9 +430,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +486,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,9 +564,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. 
Args: @@ -643,9 +639,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +648,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -694,8 +686,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -722,9 +713,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +762,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,13 +799,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -841,13 +827,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -875,7 +860,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -902,9 +887,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb..bd5801a78 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -183,13 +186,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,13 +214,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -246,7 +247,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -282,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8..a28e52c78 
100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,20 +96,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -175,8 +179,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +307,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +363,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -438,9 +435,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +468,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,13 +541,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -576,13 +569,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -610,7 +602,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b..76785a845 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,42 +237,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." 
+ ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -621,11 +622,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +662,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -680,14 +675,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +688,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +867,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1013,8 +999,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
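# Illustrative sketch, not part of the patch: the duration check above (return
# False when a cut is shorter than 1 s or longer than 20 s) is a predicate
# meant for lhotse's CutSet.filter(), so extreme utterances are dropped before
# training.  Hypothetical standalone version:
def keep_cut(c) -> bool:
    return 1.0 <= c.duration <= 20.0


# train_cuts = train_cuts.filter(keep_cut)  # assumed lhotse CutSet usage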
Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b06294..8499651d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. 
Args: @@ -1118,9 +1103,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,33 +1174,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1253,23 +1230,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1279,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ 
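# Illustrative sketch, not part of the patch: the attention scores above follow
# Transformer-XL (section 3.3): a content term matrix_ac plus a relative
# position term matrix_bd.  Shape check with made-up sizes (rel_shift is
# omitted here; the real code applies it to matrix_bd before the addition):
import torch

batch, heads, time1, time2, d_k = 2, 4, 10, 10, 32
q_with_bias_u = torch.randn(batch, heads, time1, d_k)
q_with_bias_v = torch.randn(batch, heads, time1, d_k)
k = torch.randn(batch, heads, d_k, time2)           # (batch, head, d_k, time2)
p = torch.randn(batch, heads, d_k, 2 * time1 - 1)   # projected positional encoding

matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
# attn_output_weights = matrix_ac + rel_shift(matrix_bd)  -> (batch, head, time1, time2)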
-1336,13 +1301,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1442,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1623,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1720,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( @@ -1804,7 +1757,8 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," + f" stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e..f462cc42f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,20 +179,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
- "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -303,8 +307,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -477,9 +480,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +546,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -696,9 +694,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -731,8 +727,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -787,9 +782,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -828,13 +821,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -857,13 +849,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -891,7 +882,7 @@ def main(): filename_start = 
f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -937,9 +928,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e..a739c17bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..e2da0da4c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,9 +89,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -116,10 +118,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -166,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +266,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +342,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483e..59a0e8fa2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,20 +96,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -175,8 +179,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +307,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +363,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +435,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +468,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,13 +541,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -576,13 +569,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -610,7 +602,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372e..75696d61b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -272,42 +269,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -645,11 +645,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +686,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +699,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
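The scale applied to the pruned RNN-T loss in the code just below follows a simple warm-up schedule: roughly 0.0 while warmup is below 1.0, 0.1 between 1.0 and 2.0, and 1.0 afterwards. A standalone sketch of that schedule and the combined loss in plain Python; the helper names pruned_loss_scale and combined_loss are illustrative only, and simple_loss_scale defaults to 0.5 to match --simple-loss-scale above.

def pruned_loss_scale(warmup: float) -> float:
    # Mirrors: 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
    if warmup < 1.0:
        return 0.0
    return 0.1 if 1.0 < warmup < 2.0 else 1.0


def combined_loss(
    simple_loss: float,
    pruned_loss: float,
    warmup: float,
    simple_loss_scale: float = 0.5,
) -> float:
    # The pruned term is faded in so it cannot overwhelm the simple loss
    # before the model has learned a rough alignment.
    return simple_loss_scale * simple_loss + pruned_loss_scale(warmup) * pruned_loss


assert pruned_loss_scale(0.5) == 0.0
assert pruned_loss_scale(1.5) == 0.1
assert pruned_loss_scale(2.5) == 1.0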
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +712,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +895,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,7 +1008,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1045,8 +1030,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f7..40ad61fd4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. 
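The hunk above keeps a list of encoder layer indices whose outputs should be returned: an optional middle layer (used for distillation) plus, per the comment, the last layer. A minimal self-contained sketch of that pattern follows; TinyLayeredEncoder and its linear layers are purely illustrative stand-ins, not the icefall Conformer.

import torch
import torch.nn as nn


class TinyLayeredEncoder(nn.Module):
    def __init__(self, num_layers: int = 6, dim: int = 16, middle_output_layer: int = 3):
        super().__init__()
        assert 0 <= middle_output_layer < num_layers
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        # Collect the chosen middle layer and always the last layer.
        self.output_layers = [middle_output_layer, num_layers - 1]

    def forward(self, x: torch.Tensor):
        layer_results = []
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x))
            if i in self.output_layers:
                layer_results.append(x)
        # Callers treat layer_results[-1] as the top-level encoder output.
        return layer_results


outs = TinyLayeredEncoder()(torch.randn(2, 10, 16))
print([tuple(o.shape) for o in outs])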
@@ -178,9 +175,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,33 +718,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -790,9 +773,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +781,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +817,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +842,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df04006..600aa9b39 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,20 +120,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -208,8 +212,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +270,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -334,11 +332,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -446,8 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -490,9 +481,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,13 +513,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -553,13 +541,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -587,7 +574,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c7377..17f8614dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +157,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +205,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c..86cf34877 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ import os from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b557814..b8440f90a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from pathlib import Path from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") 
return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00d..4f9417c9f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ class Transducer(nn.Module): self.decoder = decoder 
self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb..be54ff0ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,42 +201,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -569,9 +570,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. 
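As the comment notes, -100 is used as the padding value because it coincides with the default ignore_index of PyTorch's cross-entropy loss, so padded frames drop out of the codebook (distillation) loss. A small illustration of that behaviour; the logits and targets below are made-up toy data, not icefall code.

import torch
import torch.nn as nn

logits = torch.randn(5, 10)                    # 5 frames, 10 codebook classes
targets = torch.tensor([3, 7, 1, -100, -100])  # last two frames are padding
# nn.CrossEntropyLoss ignores targets equal to -100 by default, so the
# padded frames contribute nothing to the loss.
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item())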
- cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +603,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +650,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -670,14 +663,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +678,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +859,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1007,8 +991,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b..40f97f662 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,9 +68,7 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" - ) + self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -208,9 +206,7 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,10 +223,11 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + split_cmd = ( + "lhotse split" + f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" ) - split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -240,16 +237,13 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -269,8 +263,7 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +323,7 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f1..fa8144935 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,20 +164,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -272,8 +276,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -393,9 +396,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +455,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -588,9 +586,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -623,8 +619,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -679,9 +674,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,13 +711,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -747,13 +739,12 @@ def main(): 
model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -781,7 +772,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -808,9 +799,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce1..5f90e6375 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -91,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ea..43ac658e5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,20 +129,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
" + ), ) parser.add_argument( @@ -176,8 +180,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -215,13 +218,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -244,13 +246,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -278,7 +279,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -316,9 +317,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index e2405d5ef..c94a34d58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,10 +69,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -93,10 +95,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,9 +268,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afe..3ddac2cf2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..0e59b0f2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. +import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e3..460ac2c3e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed -import torch -from scaling import ActivationBalancer +import contextlib +import logging import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from scaling import ActivationBalancer from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,11 +37,10 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager def batched_params(self, param_group): """ @@ -73,7 +72,9 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -82,7 +83,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] + batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -94,76 +95,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. 
By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. 
+ clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): defaults = dict( lr=lr, @@ -183,7 +185,6 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -206,7 +207,9 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -225,13 +228,9 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +246,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,36 +254,30 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, pairs: List[Tuple[Tensor, dict]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. 
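# --- Illustrative usage sketch (editor's note, not part of the patch) --------
# How ScaledAdam and the Eden scheduler defined in this file are typically
# wired together, following the pattern used in _test_scaled_adam() further
# below.  Assumes this file is importable as `optim`; the model and data here
# are toy placeholders.
#
# import torch
# from optim import Eden, ScaledAdam
#
# model = torch.nn.Sequential(
#     torch.nn.Linear(100, 256), torch.nn.PReLU(), torch.nn.Linear(256, 100)
# )
# optimizer = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0)
# scheduler = Eden(optimizer, lr_batches=200, lr_epochs=5)
#
# for epoch in range(2):
#     scheduler.step_epoch()              # epoch-level part of the LR schedule
#     for _ in range(10):
#         x = torch.randn(4, 100)
#         loss = (model(x) ** 2).mean()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         scheduler.step_batch()          # batch-level part of the LR schedule
# ------------------------------------------------------------------------------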
@@ -314,57 +307,67 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, + (clipping_update_period // 4) * n, + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}," + f" model_norm_threshold={model_norm_threshold}" + ) return ans - - def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. 
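# --- Worked example of the clipping rule above (editor's note, not part of ----
# the patch; the numbers are hypothetical).  _get_clipping_scale() turns the
# median gradient norm over the last `clipping_update_period` minibatches into
# a threshold and rescales any gradient whose (rms-normalized) norm exceeds it:
#
#   median_norm = 4.0                          # hypothetical median of recent grad norms
#   clipping_scale = 2.0
#   threshold = clipping_scale * median_norm   # 8.0, stored as model_norm_threshold
#
#   tot_norm = 10.0                            # hypothetical norm of the current gradient
#   factor = min(1.0, threshold / (tot_norm + 1.0e-20))
#   # factor == 0.8 -> gradients are multiplied by 0.8 before the update;
#   # any norm <= 8.0 gives factor 1.0, i.e. no clipping.
# -------------------------------------------------------------------------------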
@@ -391,17 +394,18 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -411,24 +415,21 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +444,28 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. 
scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +473,9 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) + delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,8 +496,7 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +508,13 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +526,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. 
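# --- Schematic of the core update in _step() above (editor's note, not part ---
# of the patch; the denominator is only sketched and early-step bias-correction
# details are omitted).  The difference from plain Adam is the extra factor of
# the parameter's own RMS in `alpha`, so every tensor changes by a comparable
# relative amount per step:
#
#   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad**2
#   denom      <- (roughly) sqrt of the bias-corrected exp_avg_sq, plus eps
#   grad       <- grad / denom
#   alpha      <- -lr * (1 - beta1) * clamp(param_rms, min=param_min_rms)
#   delta      <- delta + grad * alpha      # `delta` is the momentum buffer
#   p          <- p + delta
# --------------------------------------------------------------------------------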
@@ -540,12 +534,11 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +548,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +669,15 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -745,13 +736,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. 
_On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +758,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +798,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +825,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +836,31 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,79 +874,93 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim = 
Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" + f" {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a..8b4d88871 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +212,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +277,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +353,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba56..4040065e1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 
if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. 
""" + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). 
""" + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,9 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. 
eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,9 +527,7 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. 
self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -535,26 +545,35 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x, + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,21 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. 
- metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float, + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +670,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +680,29 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}," + f" num_channels={x_orig.shape[-1]}," + f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +737,7 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +757,21 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. 
self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,11 +779,14 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -768,7 +795,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +810,7 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +831,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +854,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +862,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +877,9 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +892,10 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}," + f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. 
Note, we should quite rarely @@ -869,17 +903,16 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +922,39 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. 
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +982,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -959,7 +991,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1006,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -990,7 +1024,6 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1035,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1062,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1078,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1105,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1137,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1150,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. 
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f..46e775285 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,11 +26,7 @@ from typing import List import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten class NonScaledNorm(nn.Module): @@ -75,12 +71,10 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) + raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") + raise AttributeError("`" + item + "` is not an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 3f27736b3..7f9526104 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,9 +84,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -124,7 +122,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + help=( + "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" + " separated" + ), ) parser.add_argument( @@ -139,9 +140,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", + help=( + "Unmasked dimensions in the encoders, relates to augmentation during" + " training. Must be <= each of encoder_dims. Empirically, less than 256" + " seems to make performance worse." + ), ) parser.add_argument( @@ -269,42 +272,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -646,11 +652,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +699,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,9 +870,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -890,11 +888,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -905,9 +899,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -915,10 +907,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -930,7 +919,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1009,9 +999,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1029,7 +1017,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1054,8 +1042,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1229,7 +1216,8 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 023dec97d..fcd9858cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,32 +16,35 @@ # limitations under the License. 
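As an aside on the training-data filter reformatted in the train.py hunk above, here is a self-contained sketch of the same duration predicate; the function name, the `train_cuts` variable, and the use of `lhotse.CutSet.filter` are illustrative assumptions:

    import logging

    def remove_short_and_long_utt(c) -> bool:
        # Keep only cuts between 1 and 20 seconds, mirroring the predicate above;
        # everything outside that range is excluded from training.
        if c.duration < 1.0 or c.duration > 20.0:
            logging.warning(
                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            )
            return False
        return True

    # Hypothetical usage with a lhotse CutSet named `train_cuts`:
    # train_cuts = train_cuts.filter(remove_short_and_long_utt)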
import copy -import math -import warnings import itertools -from typing import List, Optional, Tuple, Union import logging -import torch +import math import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, - ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, Identity, + MaxEig, + ScaledConv1d, + Whiten, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import make_pad_mask class Zipformer(EncoderInterface): @@ -89,7 +92,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +100,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +126,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +142,11 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor, + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,27 +170,33 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has" + f" downsampling_factor={z[i]}, we will combine the outputs" + f" of layers {j} and {i-1}, with" + f" downsampling_factors={z[j]} and {z[i-1]}." 
+ ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -206,46 +216,56 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -265,13 +285,19 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + assert x.size(0) == lengths.max().item(), ( + x.shape, + lengths, + lengths.max(), + ) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -280,9 +306,11 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x 
= module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = self.downsample_output(x) @@ -312,15 +340,16 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +359,24 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,14 +384,18 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). 
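To unpack the comment above: the balancer constrains, per channel, the fraction of positive activations (0.45..0.55 here), which is what keeps the per-channel median near zero. A small diagnostic for that statistic, purely illustrative and not part of icefall:

    import torch

    def positive_fraction_per_channel(x: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
        # Fraction of positive values in each channel, averaged over all other
        # dims; ActivationBalancer nudges this toward the [min_positive,
        # max_positive] band during training.
        dims = tuple(d for d in range(x.ndim) if d != channel_dim % x.ndim)
        return (x > 0).float().mean(dim=dims)

    # e.g. activations of shape (T, N, d_model) give a (d_model,) vector whose
    # entries should sit near 0.5 once the balancer has done its job.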
self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -382,8 +410,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -398,8 +427,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -508,13 +538,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +559,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -538,15 +568,13 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. 
/ num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +607,14 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +634,13 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}," + f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," + f" num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,7 +671,6 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -670,28 +701,31 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int, + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. 
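Returning to `get_layerdrop_prob` above: each layer's drop probability ramps linearly from 0.5 down to 0.05 over that layer's warmup window. A normalized sketch of the ramp (note the hunk itself divides by `layer_warmup_end` rather than by the window length; the helper below is illustrative, not a verbatim copy):

    def layerdrop_prob(batch_count: float, warmup_begin: float, warmup_end: float,
                       initial: float = 0.5, final: float = 0.05) -> float:
        # Constant `initial` before the warmup window, constant `final` after it,
        # and a linear interpolation in between.
        if batch_count < warmup_begin:
            return initial
        if batch_count > warmup_end:
            return final
        t = (batch_count - warmup_begin) / (warmup_end - warmup_begin)
        return initial + t * (final - initial)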
@@ -718,42 +752,43 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -767,16 +802,14 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +828,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +846,7 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +854,7 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +864,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. 
""" - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -853,10 +882,14 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +902,9 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -888,9 +918,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -905,9 +933,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -955,7 +981,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1017,46 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, + in_proj_dim, + bias=True, + initial_scale=self.head_dim**-0.25, + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. 
- self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1031,14 +1068,16 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1137,6 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1156,26 +1194,24 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + assert head_dim * num_heads == attention_dim, ( + f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," + f" {attention_dim}" + ) # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,33 +1231,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,7 +1258,6 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1266,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,13 +1280,16 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1275,10 +1302,9 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. 
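The as_strided() expression above re-indexes the last axis from relative offsets to absolute key positions without copying memory. An equivalent (but copying, and easier to read) gather-based version is sketched below; the helper name is illustrative, and it follows the same offset convention the original comment says is defined by that code:

    import torch

    def rel_to_abs(pos_weights: torch.Tensor) -> torch.Tensor:
        # pos_weights: (batch, heads, seq_len, 2*seq_len - 1); the last axis is
        # indexed by (seq_len - 1) + (j - i) for query position i, key position j.
        # Returns (batch, heads, seq_len, seq_len) scores at absolute positions.
        b, h, t, _ = pos_weights.shape
        i = torch.arange(t, device=pos_weights.device).unsqueeze(1)
        j = torch.arange(t, device=pos_weights.device).unsqueeze(0)
        idx = (t - 1) + (j - i)                    # (t, t), values in [0, 2t-2]
        idx = idx.view(1, 1, t, t).expand(b, h, t, t)
        return pos_weights.gather(dim=-1, index=idx)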
- attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1320,20 +1346,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [ + bsz * num_heads, + seq_len, + head_dim // 2, + ] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1372,11 +1398,7 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1409,50 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}," + f" covar={attn_covar}, 
in_proj_covar={in_proj_covar}," + f" out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - def __init__(self, - d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1430,7 +1463,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1477,19 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1481,9 +1509,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1539,10 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1556,10 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1575,10 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. 
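The PoolingModule defined earlier in this hunk averages over time while respecting padding; its weighted sum is just a per-utterance masked mean. A minimal standalone sketch (without the trailing projection):

    import torch

    def masked_time_average(x: torch.Tensor, key_padding_mask: torch.Tensor) -> torch.Tensor:
        # x: (T, N, C); key_padding_mask: (N, T), True at padded positions.
        # Returns (1, N, C): the mean of x over non-padded frames, written as a
        # weighted sum as in PoolingModule.forward above.
        weights = key_padding_mask.logical_not().to(x.dtype)    # (N, T)
        weights = weights / weights.sum(dim=1, keepdim=True)    # rows sum to 1
        weights = weights.transpose(0, 1).unsqueeze(-1)         # (T, N, 1)
        return (x * weights).sum(dim=0, keepdim=True)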
@@ -1626,8 +1658,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1667,21 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1702,7 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1746,12 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1782,35 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. 
] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1825,6 @@ class AttentionCombine(nn.Module): return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1833,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1851,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1837,19 +1872,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 9d7335e77..822f8e44b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,20 +165,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. 
Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -273,8 +277,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -394,9 +397,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -455,10 +456,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -589,9 +587,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -624,8 +620,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -680,9 +675,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -719,13 +712,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -753,13 +745,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -788,7 +779,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from 
" f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -816,9 +807,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 49f469e29..43eb0c1bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,20 +129,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -176,8 +180,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -217,13 +220,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -252,13 +254,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -287,7 +288,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -326,9 +327,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index e79a3a3aa..ed920dc03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,10 +69,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -93,10 +95,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,9 +268,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 497b89136..39a360796 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,9 +160,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 373a48fc1..716136812 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +212,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +277,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +353,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 2603bb854..381a86a67 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -132,7 +130,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + help=( + "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" + " separated" + ), ) parser.add_argument( @@ -147,9 +148,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", + help=( + "Unmasked dimensions in the encoders, relates to augmentation during" + " training. Must be <= each of encoder_dims. Empirically, less than 256" + " seems to make performance worse." + ), ) parser.add_argument( @@ -214,8 +217,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -285,42 +287,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -691,11 +696,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -744,9 +745,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -952,9 +951,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -975,11 +972,7 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -992,12 +985,8 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1011,10 +1000,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1026,7 +1012,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1054,8 +1041,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False @@ -1152,9 +1138,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1172,7 +1156,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1207,9 +1191,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1364,7 +1346,8 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b..53f383c99 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
+streaming_models/
+|-- lang_bpe
+|   |-- L.pt
+|   |-- Linv.pt
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index ff4c91446..4f7427c1f 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,36 +309,26 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(
-                    0, num_frames - embed_left_context + 1, stride
-                ):
+                for cur in range(0, num_frames - embed_left_context + 1, stride):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(
-                        cur_feature, offset
-                    )
-                    cur_embed = cur_embed.permute(
-                        1, 0, 2
-                    )  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
+                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[
-                            0, (chunk_size - 1), :
-                        ].view(1, 1, -1)
+                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
+                            1, 1, -1
+                        )
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
-                        0
-                    )
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
-                        0
-                    )
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -413,9 +403,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -431,22 +419,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -480,9 +462,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -554,9 +534,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -755,9 +731,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -783,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, offset: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -813,9 +785,7 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(
-                        1, 1, self.pe.size(-1)
-                    ),
+                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1050,9 +1020,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1120,33 +1090,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1185,24 +1147,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(
-            matrix_bd, offset=offset
-        )  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,13 +1190,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index a74c51836..5a8149aad 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,6 +28,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -62,32 +63,36 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
         "--chunk-size",
         type=int,
         default=8,
-        help="Frames of right context"
-        "-1 for whole right context, i.e. non-streaming decoding",
+        help=(
+            "Frames of right context"
+            "-1 for whole right context, i.e. non-streaming decoding"
+        ),
     )
 
     parser.add_argument(
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,"
-        "only used during decoding",
+        help="tailing dummy frames padded to the right,only used during decoding",
     )
 
     parser.add_argument(
@@ -139,8 +144,7 @@ def get_parser():
         "--avg-models",
         type=str,
         default=None,
-        help="Manually select models to average, seperated by comma;"
-        "e.g. 60,62,63,72",
+        help="Manually select models to average, seperated by comma;e.g. 60,62,63,72",
     )
 
     return parser
@@ -248,13 +252,9 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(
-        memory_key_padding_mask, 0
-    )  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [
-        remove_duplicates_and_blank(token_id) for token_id in token_ids
-    ]
+    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,9 +337,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -357,15 +355,18 @@ def save_results(
     test_set_wers = dict()
     if params.avg_models is not None:
         avg_models = params.avg_models.replace(",", "_")
-        result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
-        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        result_file_prefix = (
+            f"epoch-avg-{avg_models}-chunksize        "
+            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        )
     else:
-        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
-        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        result_file_prefix = (
+            f"epoch-{params.epoch}-avg-{params.avg}-chunksize        "
+            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        )
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir
-            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -374,8 +375,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir
-            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,9 +384,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -474,9 +472,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -507,9 +503,7 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index e41b7ea78..553b7d092 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,9 +405,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -436,9 +434,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )
 
     return loss, info
@@ -551,9 +547,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -668,9 +662,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index bc78e4a41..0c87fdf1b 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -286,23 +284,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -363,23 +355,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -652,9 +638,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -856,9 +840,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -879,9 +861,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 355ccc99a..63afd6be2 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -77,17 +77,18 @@ class LibriSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -99,59 +100,74 @@ class LibriSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--drop-last",
@@ -163,17 +179,18 @@ class LibriSpeechAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -187,18 +204,22 @@ class LibriSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -224,20 +245,16 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -252,9 +269,7 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,9 +313,7 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -356,9 +369,7 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7d0cd0bf3..94ba0a4dc 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -57,16 +57,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -336,9 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +401,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -467,9 +466,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -498,9 +495,7 @@ def main():
             G=G,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..722e8f003 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 6b37d5c23..071ac792b 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,9 +355,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
-        .sum()
-        .item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
     )
 
     return loss, info
@@ -470,9 +468,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 11032f31a..b45b6a9d8 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -123,9 +121,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -157,9 +153,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 5f233df87..f30332cea 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -71,16 +71,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -228,9 +231,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -245,9 +246,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -318,9 +317,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -353,8 +350,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 5a5db30c4..4d9f937f5 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -67,17 +67,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -238,9 +241,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 1db2df648..7aadfbcd1 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -60,9 +60,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -87,10 +89,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -188,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -249,9 +252,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -287,9 +288,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index 2a165b0c1..fe8732301 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,12 +117,8 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -348,9 +344,7 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(
-            input_size=input_size, **factory_kwargs
-        )
+        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -385,9 +379,7 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(
-            List[Tuple[torch.Tensor, torch.Tensor]], []
-        )
+        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -456,12 +448,8 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 8591e2d8a..74c94cc70 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (
-        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
-    ).sum().backward()
+    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [
-        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
-    ]
+    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 1dd65eddb..674ea10a6 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,9 +396,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +655,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
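
For completeness, a tiny sketch (not part of the patch) of the `frames` statistic the `train.py` hunk above puts back on one line: feature lengths are integer-divided by the subsampling factor before being summed. The lengths and factor below are made up.

```python
# Sketch of the frames statistic from the train.py hunks above.
import torch

subsampling_factor = 4
feature_lens = torch.tensor([1000, 860, 420])
frames = (feature_lens // subsampling_factor).sum().item()
print(frames)  # 250 + 215 + 105 = 570
```
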
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 3531a9633..5342c3e8c 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -124,9 +122,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -158,9 +154,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
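
A small sketch (not part of the patch) of the per-hypothesis cache idiom in the `transducer_lstm/beam_search.py` hunks above: the decoder output and its LSTM state are memoised per token history, keyed by the joined token ids. The tensor sizes and the fake state below are assumptions.

```python
# Sketch of the decoder-output cache keyed by the hypothesis token sequence.
from typing import Dict, Tuple

import torch

cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}

ys = [0, 15, 27]  # token history of the current best hypothesis
cached_key = "_".join(map(str, ys))

if cached_key not in cache:
    decoder_input = torch.tensor([ys[-1]]).reshape(1, 1)
    # The real code calls model.decoder(decoder_input, decoder_state);
    # here the output and the (h, c) state are faked for illustration.
    decoder_out = torch.randn(1, 1, 512)
    decoder_state = (torch.zeros(2, 1, 512), torch.zeros(2, 1, 512))
    cache[cached_key] = (decoder_out, decoder_state)

decoder_out, decoder_state = cache[cached_key]
```
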
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 604235e2a..61b9de504 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -71,16 +71,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=77,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=55,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -225,9 +228,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -242,9 +243,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -315,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -350,8 +347,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 3dc992dd2..038d80077 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,9 +48,7 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index cdb801e79..57bda63fd 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,9 +400,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -524,9 +522,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -665,9 +661,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index f143611ea..65f2c58d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,9 +193,7 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index ea985f30d..1d79eef9d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,9 +478,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -496,9 +494,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -786,9 +782,7 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -887,9 +881,7 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -959,9 +951,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 48769e9d1..89992856d 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -54,16 +54,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -124,8 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -162,9 +164,7 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
         batch_size = encoder_out.size(0)
 
@@ -204,9 +204,7 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cde52c9fc..d279eae85 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,10 +209,7 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -421,9 +418,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -439,22 +434,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -486,9 +475,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -514,9 +501,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -581,9 +566,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -625,9 +608,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(
-            src, states[1], right_context=right_context
-        )
+        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -779,9 +760,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -798,9 +777,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -826,9 +803,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, left_context: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1092,9 +1067,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1163,33 +1138,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1228,14 +1195,10 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1243,9 +1206,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1290,9 +1251,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1304,13 +1263,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1418,16 +1373,12 @@ class ConvolutionModule(nn.Module):
                 # manually padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
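
For orientation, a shape-only sketch (not part of the patch) of the relative-position attention score assembled in the `conformer.py` hunks above: `matrix_ac` is the content term, `matrix_bd` the position term, and their scaled sum feeds the softmax. All sizes below are arbitrary, and the rel-shift is faked with a slice purely to line the shapes up.

```python
# Shape sketch of the attention-score computation touched in the hunks above.
import torch

batch, head, time1, time2, d_k = 2, 4, 10, 10, 32
scaling = d_k**-0.5

q_with_bias_u = torch.randn(batch, head, time1, d_k)
q_with_bias_v = torch.randn(batch, head, time1, d_k)
k = torch.randn(batch, head, d_k, time2)  # keys, already transposed
p = torch.randn(batch, head, d_k, 2 * time1 - 1)  # projected relative positions

matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
matrix_bd = matrix_bd[..., :time2]  # stand-in for the real rel_shift

attn_output_weights = (matrix_ac + matrix_bd) * scaling
print(attn_output_weights.shape)  # torch.Size([2, 4, 10, 10])
```
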
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 74bba9cad..314f49154 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -94,16 +94,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -171,8 +174,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +232,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +440,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index fbc2373a9..a182d91e2 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,9 +87,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
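
A minimal sketch (not part of the patch) of the causal left-padding the `decoder.py` hunk above reflows: in training, the embedding output is padded with `context_size - 1` frames on the left so each position only sees earlier tokens; at inference a single output needs no padding. The shapes below are placeholders.

```python
# Sketch of the left-padding idiom from the hunk above.
import torch
import torch.nn.functional as F

context_size = 2
# embedding_out after permute: (N, embedding_dim, U)
embedding_out = torch.randn(3, 16, 7)
need_pad = True

if need_pad:
    # Pad only the last (time) dimension, on the left.
    embedding_out = F.pad(embedding_out, pad=(context_size - 1, 0))

print(embedding_out.shape)  # torch.Size([3, 16, 8])
```
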
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 8bd0bdea1..7c10b4348 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -68,17 +68,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -109,8 +112,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -244,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 93cccbd8c..e1625992d 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,13 +60,9 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
 
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
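
A toy sketch (not part of the patch) of the broadcasting add that the `joiner.py` hunks above build per utterance: encoder frames and decoder steps are combined into a (T, U, dim) grid via `unsqueeze`. T, U and dim below are arbitrary.

```python
# Sketch of the encoder/decoder combination used in the joiner hunks above.
import torch

T, U, dim = 5, 3, 8
e = torch.randn(T, dim)  # encoder frames for one utterance
d = torch.randn(U, dim)  # decoder outputs for the same utterance

x = e.unsqueeze(1) + d.unsqueeze(0)  # (T, 1, dim) + (1, U, dim) -> (T, U, dim)
print(x.shape)  # torch.Size([5, 3, 8])
```
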
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index b64521801..bd7eeff28 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index b00fc34f1..9af46846a 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,16 +140,13 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second)
-                for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(
-                zip(words, word_starting_time)
-            )
+            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -160,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index d1350c8ab..65b08d425 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,9 +29,7 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(
-        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
-    )
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index ae93f3348..bcb883fa5 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -422,9 +421,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -545,9 +542,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -664,13 +659,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -698,9 +689,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index ac2807241..86ef9e5b6 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -94,16 +94,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -171,8 +174,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +232,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +440,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 57c1a6094..d95eeb1f4 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -63,17 +63,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -104,8 +107,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -176,9 +178,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 292f77f03..793931e3b 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index ea15c9040..68e247f23 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -410,9 +409,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -533,9 +530,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,13 +647,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -686,9 +677,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index d596e05cb..22b6ab911 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -95,16 +95,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding. Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -172,8 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +233,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -249,10 +249,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -298,11 +295,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -375,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -410,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -451,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index b6b69d932..fad9a6977 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -69,17 +69,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding. Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -110,8 +113,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +249,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index f297fa2b2..efd257b5d 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index ef51a7811..1e1188ca6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,9 +41,7 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 27912738c..88987d91c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,8 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. "
-        "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -170,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -469,9 +467,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -635,9 +631,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -784,9 +778,7 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -825,9 +817,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index af54dbd07..bed3856e4 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,9 +135,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 877720e7b..3790045fa 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,9 +54,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 6cb8b65ae..9bea28a41 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,9 +87,7 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(
-                part["recordings"] for part in manifests.values()
-            )
+            recordings=combine(part["recordings"] for part in manifests.values())
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -108,8 +106,6 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 8116e7605..20ff6d7ab 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = (
-            min(args.stop, args.num_splits)
-            if args.stop > 0
-            else args.num_splits
-        )
+        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(
-                src_dir / f"cuts_{partition}_raw.jsonl.gz"
-            )
+            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 8c8f1c133..508d4acd8 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,9 +55,7 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts
-        + train_cuts.perturb_speed(0.9)
-        + train_cuts.perturb_speed(1.1)
+        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -73,9 +71,7 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index f165f6e60..83f95d123 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -70,10 +70,12 @@ class SPGISpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -85,67 +87,81 @@ class SPGISpeechAsrDataModule:
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it "
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--max-duration",
             type=int,
             default=100.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the BucketingSampler "
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=8,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -157,10 +173,12 @@ class SPGISpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
     def train_dataloaders(
@@ -176,24 +194,20 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -208,9 +222,7 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -227,9 +239,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
             )
         else:
@@ -282,9 +292,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -328,9 +336,7 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index c39bd0530..72a7cd1c1 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -117,9 +113,11 @@ def get_parser():
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,9 +243,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -263,10 +258,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -312,11 +304,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -389,9 +377,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -424,9 +410,7 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -438,32 +422,23 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {
-        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
-    }
-    test_set_cers = {
-        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
-    }
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(
-                    key, test_set_wers[key], test_set_cers[key]
-                ),
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(
-            key, test_set_wers[key], test_set_cers[key], note
-        )
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
         note = ""
     logging.info(s)
 
@@ -496,9 +471,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -530,8 +503,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 77faa3c0e..1f18ae2f3 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import str2bool
 
 
@@ -67,17 +63,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding. Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -119,8 +118,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -196,9 +194,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index dda29b3e5..cd835a7b4 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -155,8 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be "
-        "changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -179,42 +176,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss; it means how many symbols (context) "
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version of the "
+            "loss (the joiner is just an addition); this simple loss is also "
+            "used for training (as a regularization term). We will scale the "
+            "simple loss with this parameter before adding it to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -554,23 +554,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -733,9 +726,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 4582609ac..602e50d29 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -112,9 +110,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 2c5b8b8b3..1262baf63 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 71be2a613..2be639b7a 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help="number of characters to split, i.e., \
-                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+        help=(
+            "number of characters to split, i.e., "
+            "aabb -> a a b b with -n 1 and aa bb with -n 2"
+        ),
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +66,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +106,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +132,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 49bfb148b..02bd6e2cc 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,10 +74,12 @@ class TAL_CSASRAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
 
         group.add_argument(
@@ -91,66 +93,81 @@ class TAL_CSASRAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler "
+                "(you might want to increase it for larger datasets)."
+            ),
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
 
         group.add_argument(
@@ -164,17 +181,18 @@ class TAL_CSASRAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -188,18 +206,22 @@ class TAL_CSASRAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -222,24 +244,20 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -254,9 +272,7 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +316,7 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -360,9 +374,7 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b624913f5..b2aef7e86 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -124,20 +124,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -208,8 +212,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -268,9 +271,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -303,10 +304,7 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +373,7 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode(
-                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
-                )
+                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -396,11 +392,11 @@ def decode_one_batch(
         return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): (hyps, zh_hyps, en_hyps)
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": (
+                hyps,
+                zh_hyps,
+                en_hyps,
+            )
         }
     else:
         return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
@@ -506,9 +502,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results, zh_results, en_results
 
 
@@ -541,8 +535,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -585,9 +578,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -619,13 +610,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -648,13 +638,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -682,7 +671,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 8f900208a..94a4c7a2e 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -92,20 +92,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -139,8 +143,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -176,13 +179,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -205,13 +207,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -239,7 +240,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -277,9 +278,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index dbe213b24..198242129 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -84,9 +84,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -115,10 +117,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -165,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -263,15 +265,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -367,9 +365,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index ca35eba45..676e8c904 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,9 +86,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -214,8 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -238,42 +235,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -600,11 +600,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -634,22 +630,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -828,9 +817,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -944,7 +931,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 327962a79..733ebf235 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,9 +83,7 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -104,9 +102,7 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 49544ccb3..9dbcc9d9e 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,9 +25,7 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--texts", type=List[str], help="The input transcripts list."
-    )
+    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index 35dd332e8..b9160b6d4 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -88,9 +87,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 1039ac5bb..7ea4e89a4 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         texts.append(s.text)
 
@@ -83,9 +82,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 2b294e601..6bae33e65 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -94,17 +94,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -172,8 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +233,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index a1c3bcea3..244740932 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -65,17 +65,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -106,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -179,9 +181,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8480ac029..00545f107 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -93,9 +93,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -122,10 +124,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -165,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,10 +206,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -271,9 +273,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,10 +298,7 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -353,9 +350,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 8d5cdf683..70c5e290f 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,42 +133,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -556,9 +559,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -678,9 +679,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..86ac2fea3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -63,10 +63,12 @@ class TedLiumAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -78,75 +80,91 @@ class TedLiumAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -160,18 +178,22 @@ class TedLiumAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
@@ -179,20 +201,16 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -207,9 +225,7 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -253,9 +269,7 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -300,9 +314,7 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -358,13 +370,9 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 77caf6460..1f99edaf3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,9 +148,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -166,9 +164,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -344,9 +340,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -383,9 +379,7 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -454,9 +448,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index d3e9e55e7..12d0e2652 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -81,16 +81,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -130,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -250,9 +252,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,9 +275,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -348,9 +346,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -383,8 +379,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f0c6f32b6..f9a3814c6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,9 +90,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c32b1d002..0b2ae970b 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -69,17 +69,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -110,8 +113,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +249,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index c0e3bb844..912d65497 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -82,9 +82,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -110,10 +112,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -127,8 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -222,10 +225,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -285,9 +287,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,9 +335,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 09cbf4a00..6fed32e81 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -525,9 +524,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +644,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index b78c16b88..d8ceb82b6 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
\ No newline at end of file
+```
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 58cab4cf2..32c248d7e 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,9 +146,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index f25786a0c..ecdf10ba9 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,9 +85,7 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -101,9 +99,7 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 04023a9ab..0cf0f0deb 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = (
-        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
-    )
+    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -97,9 +95,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index ae1b96a68..d11cd3a05 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it 
+#	     on 39 phones. About how to get these LM files, you can know it
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#	
+#
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 4f2aa2340..5a59a13ce 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -57,16 +57,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -336,9 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +401,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -462,9 +461,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -485,9 +482,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 4d2199ace..9a594a969 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
-
 from torch import Tensor
-from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,9 +261,7 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(
-                    self.num_layers, self.batch_size * 2, self.hidden_size
-                )
+                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -445,9 +443,7 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(
-                        self.N_drop_masks, self.hidden_size, device=w.device
-                    )
+                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 7da285944..da669bc39 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 452c2a7cb..48b7feda0 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1554e987f..d957c22e1 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -63,10 +63,12 @@ class TimitAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--feature-dir",
@@ -78,75 +80,91 @@ class TimitAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -154,15 +172,13 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.feature_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -178,9 +194,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -212,9 +228,7 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -263,9 +277,7 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -299,20 +311,14 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
-            )
+            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 5e7300cf2..319ee5515 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -56,16 +56,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=25,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -335,9 +338,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -399,9 +400,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,9 +460,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -483,9 +480,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index 51edb97e2..e211ad80d 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,10 +74,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
-                for _ in range(4)
-            ]
+            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 5f478da1c..0c72c973b 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index 849256b98..be1ecffaa 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index 8a9f6ed30..bd73e520e 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,12 +20,7 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import (
-    CutSet,
-    KaldifeatFbank,
-    KaldifeatFbankConfig,
-    LilcomHdf5Writer,
-)
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index a882b6113..c228597b8 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -62,8 +62,10 @@ def get_parser():
         "--batch-duration",
         type=float,
         default=600.0,
-        help="The maximum number of audio seconds in a batch."
-        "Determines batch size dynamically.",
+        help=(
+            "The maximum number of audio seconds in a batch."
+            "Determines batch size dynamically."
+        ),
     )
 
     parser.add_argument(
@@ -152,9 +154,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index 8bc073c75..d8622842f 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 817969c47..93ce750f8 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,11 +115,7 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 1c463cf1c..e121d842c 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help="number of characters to split, i.e., \
-                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+        help=(
+            "number of characters to split, i.e.,                         aabb -> a a b"
+            " b with -n 1 and aa bb with -n 2"
+        ),
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +66,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +106,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +132,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 755fbb2d7..da7d7e061 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq 
+      echo "This script is intended to be used with jq but you have not installed jq
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..bd92ac115 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,10 +81,12 @@ class WenetSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -96,75 +98,91 @@ class WenetSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -178,18 +196,22 @@ class WenetSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -212,24 +234,20 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -244,9 +262,7 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -289,9 +305,7 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +362,7 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -414,8 +426,7 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir
-            / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -427,13 +438,9 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index f0c9bebec..6e856248c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,11 +114,7 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -137,25 +133,30 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help="It specifies the batch checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the batch checkpoint to use for decoding."
+            "Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -252,8 +253,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,9 +328,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -389,10 +387,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -438,11 +433,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -515,9 +506,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -550,8 +539,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -663,9 +651,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
@@ -716,8 +702,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -727,8 +712,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -739,9 +723,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 933642a0f..c742593df 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -126,17 +126,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -205,8 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -468,13 +470,9 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_encoder_proj.onnx"
-    )
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
 
-    decoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_decoder_proj.onnx"
-    )
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -645,9 +643,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index e5cc47bfe..ed9020c67 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -107,10 +107,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -145,10 +147,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -331,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index c396c50ef..a46ff5a07 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,9 +219,7 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {
-            encoder_proj_input_name: encoder_out.numpy()
-        }
+        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -230,16 +228,10 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), (
-            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {
-            decoder_proj_input_name: decoder_out.numpy()
-        }
+        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -248,11 +240,7 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), (
-            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
 
 
 @torch.no_grad()
@@ -304,9 +292,7 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index 3770fbbb4..f7d962008 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -111,10 +111,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -149,10 +151,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -200,11 +201,7 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {
-            joiner_encoder_proj.get_inputs()[
-                0
-            ].name: packed_encoder_out.data.numpy()
-        },
+        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -389,9 +386,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 9a549efd9..26c9c2b8c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -80,9 +80,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -107,10 +109,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -158,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d3cc7c9c9..e020c4c05 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,9 +115,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -219,42 +217,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -590,22 +591,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -762,9 +756,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -864,7 +856,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index dd27c17f0..1023c931a 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,10 +210,7 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -433,9 +430,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -453,9 +448,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -520,9 +513,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -784,9 +773,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1073,9 +1060,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1144,33 +1131,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1208,23 +1187,15 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1265,21 +1236,17 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(
-                    1
-                ).unsqueeze(2)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
             else:
                 # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(
-                    0
-                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1291,13 +1258,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1430,16 +1393,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 344e31283..3d66f9dc9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -160,20 +160,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -244,8 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -342,9 +345,7 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -360,10 +361,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -409,11 +407,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -484,9 +478,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -519,8 +511,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -589,13 +580,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -618,13 +608,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -652,7 +641,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -720,8 +709,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -731,8 +719,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -743,9 +730,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index 386248554..e522943c0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,9 +75,7 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -91,13 +89,11 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
-                k2.RnntDecodingStream(decoding_graph)
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
+                decoding_graph
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     @property
     def done(self) -> bool:
@@ -126,13 +122,10 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(
-            self.num_frames - self.num_processed_frames, chunk_length
-        )
+        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_length
+            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index d0a7fd69f..fb53f70ab 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -90,17 +90,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -131,8 +134,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -201,9 +203,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1b064c874..9834189d8 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -80,9 +80,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -107,10 +109,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -157,8 +161,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 651aff6c9..810d94135 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,14 +173,10 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
-                num_active_paths
-            )
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ff96c6487..31a7fe605 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -119,20 +119,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -201,8 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -311,9 +314,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -333,9 +334,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -389,9 +388,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_state(params.left_context, device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -461,9 +458,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     return {key: decode_results}
 
@@ -499,8 +494,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -565,13 +559,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -594,13 +587,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -628,7 +620,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 2052e9da7..40c9665f7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,9 +98,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -260,8 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -284,42 +281,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -665,11 +665,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -701,23 +697,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -841,9 +830,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -901,9 +888,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1016,7 +1001,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1184,9 +1169,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f83be05cf..7234ca929 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,9 +128,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 9a4e8a36f..75d95df68 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,9 +54,7 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(
-        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
-    )
+    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -71,9 +69,7 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -87,9 +83,7 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 85e5f1358..21860d2f5 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -56,10 +56,12 @@ class YesNoAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--feature-dir",
@@ -71,75 +73,91 @@ class YesNoAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=30.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=False,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=10,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -150,7 +168,7 @@ class YesNoAsrDataModule(DataModule):
         transforms = []
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 9d4ab4b61..41afe0404 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -35,16 +35,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=14,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=2,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -201,9 +204,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -274,9 +275,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -297,9 +296,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -317,9 +314,7 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..09a8672ae 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -41,9 +41,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -53,18 +55,18 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -101,10 +103,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -159,9 +160,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -201,9 +200,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index f32a27f35..335493491 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index 6714180db..de478334e 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -48,16 +48,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=125,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--exp-dir",
@@ -116,9 +119,7 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -186,9 +187,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -303,9 +302,7 @@ def main():
         model=model,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index deb92107d..88866ae81 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 235160e14..c31db6e4c 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,9 +71,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -96,9 +94,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 5069b78e8..8aa0a8eeb 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,15 +292,11 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [
-        (int(pattern.search(c).group(1)), c) for c in checkpoints
-    ]
+    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(
-        iter_checkpoints, reverse=True, key=lambda x: x[0]
-    )
+    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -469,7 +465,5 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += (
-                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-            )
+            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..6cd87bdc0 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,13 +334,9 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -370,9 +366,7 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(
-            path_lattice, use_double_scores=use_double_scores
-        )
+        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -442,9 +436,7 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(
-            scores_shape, self.fsa.scores.contiguous()
-        )
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
 
         tot_scores = ragged_scores.sum()
 
@@ -678,9 +670,7 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -787,13 +777,9 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
-            logging.info(
-                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -805,9 +791,7 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(
-                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,9 +878,7 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index b075aceac..7b58ffbd4 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,7 +121,9 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+        self.stats = (
+            None  # we'll later assign a list to this data member.  It's a list of dict.
+        )
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -133,7 +135,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -185,17 +186,12 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if (
-                        this_dim_stats[stats_type] != []
-                        and stats_type == "eigs"
-                    ):
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                        # don't accumulate "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(
-                            TensorAndCount(stats, count)
-                        )
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -211,7 +207,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
-
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -221,7 +216,8 @@ class TensorDiagnostic(object):
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
+                        [x.tensor / get_count(x.count) for x in stats_list],
+                        dim=0,
                     )
 
                 if stats_type == "eigs":
@@ -229,9 +225,7 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -242,9 +236,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -261,15 +255,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "eigs" ]:
+                if stats_type in ["value", "rms", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
+                rms = (stats**2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -277,17 +271,17 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                )
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str},"
+                    f" {stats_type} {ans}"
                 )
 
 
-
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -345,32 +339,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 7016beafb..9df1c5bd1 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,9 +29,7 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = (
-            "12354" if master_port is None else str(master_port)
-        )
+        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 8aeda6be2..373e9a9ff 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,9 +53,7 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = (
-            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-        )
+        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 570ed7d7a..e2ff03f61 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,9 +75,7 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
-            transcript_fsa
-        )
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index fbcf5e148..398a5f689 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
+
 import torch
 from torch import Tensor, nn
-import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -56,7 +57,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite" # ": {_output}"
+                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -65,28 +66,20 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(
-                            f"The sum of {_name}.grad[{i}] is not finite"
-                        )
+                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(
-                grad, _name=name
-        ):
+        def param_backward_hook(grad, _name=name):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(
-                    f"The sum of {_name}.param_grad is not finite"
-                )
+                logging.warning(f"The sum of {_name}.param_grad is not finite")
 
         parameter.register_hook(param_backward_hook)
 
 
-
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 80bd7c1ee..22e1b78bb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
-                logging.info(
-                    "Every line is expected to contain at least 2 fields"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info("Every line is expected to contain at least 2 fields")
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -119,9 +113,7 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError(
-            "It's assumed that each word has a unique pronunciation"
-        )
+        raise RuntimeError("It's assumed that each word has a unique pronunciation")
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 2c479fc2c..16ed6e032 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes])
-        .t()
-        .reshape(-1)
-        .to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -115,20 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(
-        num_graphs, dense_fsa_vec, output_beam=beam_size
-    )
-    den_lats = k2.intersect_dense(
-        den_graphs, dense_fsa_vec, output_beam=beam_size
-    )
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
+    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -168,13 +157,9 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 0d901227d..9f680f83d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,9 +137,7 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(
-            transcript_fsa_with_self_loops
-        )
+        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -155,9 +153,7 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
+            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 550801a8f..9a275bf28 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -46,16 +46,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -194,7 +197,7 @@ def main():
 
     logging.info(f"Number of model parameters: {num_param}")
     logging.info(
-        f"Number of model parameters (requires_grad): "
+        "Number of model parameters (requires_grad): "
         f"{num_param_requires_grad} "
         f"({num_param_requires_grad/num_param_requires_grad*100}%)"
     )
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 598e329c4..4bf982503 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,12 +155,8 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
-        y = sentence_tokens_with_eos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
+        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
+        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 094035fce..2e878f5c8 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -38,17 +38,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -159,9 +162,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index a6144727a..9eef88840 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,9 +129,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -161,12 +159,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -181,12 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
 
         device = next(self.parameters()).device
 
@@ -194,9 +188,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index bb5f03fb9..e17b50332 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,17 +446,13 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar(
-                    "train/tot_ppl", tot_ppl, params.batch_idx_train
-                )
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -471,8 +467,7 @@ def train_one_epoch(
 
             valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
             logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
-                f"ppl: {valid_ppl}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}"
             )
 
             if tb_writer is not None:
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index c2edd823e..a3bf1ef4c 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,30 +15,50 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import sys
-import os
-import re
+import argparse
 import io
 import math
-import argparse
+import os
+import re
+import sys
 from collections import Counter, defaultdict
 
-
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm",
+    type=str,
+    default=None,
+    help="Path to output arpa file for language models",
+)
+parser.add_argument(
+    "-verbose",
+    type=int,
+    default=0,
+    choices=[0, 1, 2, 3, 4, 5],
+    help="Verbose level",
+)
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = (
+    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+)
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +72,9 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(
+            set
+        )  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -62,10 +84,15 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                      for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +112,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
+    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +130,48 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
+        infile = io.TextIOWrapper(
+            sys.stdin.buffer, encoding=default_encoding
+        )  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +181,12 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -153,9 +194,11 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-                      # but perhaps this is not the case for some other scenarios.
+        self.d = [
+            0
+        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+        # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +208,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
-                                                                # which could happen if the number of symbols is small.
-                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(
+                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
+            )  # We are doing this max(0.1, xxx) to avoid zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +227,9 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                counts_for_hist.word_to_f[w] = (
+                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +243,17 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
+                else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +293,18 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                        for (
+                            u
+                        ) in (
+                            a_counts_for_hist.word_to_count.keys()
+                        ):  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                                1.0 - sum_z1_f_z
+                            )
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +318,9 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +383,40 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
 
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
 
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +428,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                     if bow is not None:
-                        line += '\t{0}'.format('%.7f' % math.log10(bow))
+                        line += "\t{0}".format("%.7f" % math.log10(bow))
                     print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +453,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index c502cb4d8..0beb94b2e 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,9 +130,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -280,13 +278,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -355,9 +349,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -578,9 +570,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -590,9 +580,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -606,9 +594,7 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -783,9 +769,7 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -795,9 +779,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -811,9 +793,7 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -883,9 +863,7 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -919,9 +897,7 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -967,8 +943,8 @@ def concat(
         ans = k2.ragged.cat([ragged, pad], axis=1)
     else:
         raise ValueError(
-            f'Unsupported direction: {direction}. " \
-            "Expect either "left" or "right"'
+            f'Unsupported direction: {direction}. '
+            'Expect either "left" or "right"'
         )
     return ans
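
A hedged usage sketch of the `concat` helper whose signature and docstring appear in the hunks above; the `k2.RaggedTensor` literal below assumes a reasonably recent k2 and is only illustrative.

```python
import k2

from icefall.utils import concat

ragged = k2.RaggedTensor([[1, 2], [3]])
left = concat(ragged, value=0, direction="left")     # sublists become [0, 1, 2] and [0, 3]
right = concat(ragged, value=-1, direction="right")  # sublists become [1, 2, -1] and [3, -1]
# concat(ragged, value=0, direction="up")            # would raise the ValueError above
```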
 
@@ -1093,9 +1069,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1118,9 +1092,7 @@ def measure_weight_norms(
         return norms
 
 
-def measure_gradient_norms(
-    model: nn.Module, norm: str = "l1"
-) -> Dict[str, float]:
+def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1405,9 +1377,7 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
+        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index b4f8c3377..3183055d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 80
+line-length = 88
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index 6c720e121..ccd2503ff 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
-from setuptools import find_packages, setup
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
 
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 511a11c23..34e829642 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,11 +20,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    load_checkpoint,
-    save_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 97964ac67..4c2e192a7 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,6 +23,7 @@ You can run this file in one of the two ways:
 """
 
 import k2
+
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index ccfb57d49..10443cf22 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,9 +154,7 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(
-            decoding_graph, fsas, treat_epsilons_specially=False
-        )
+        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6a9ce7853..31f06bd51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,9 +50,7 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor(
-                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
-            ),
+            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
         )
     )
     assert texts == ["two", "one", "three"]

From d89766d85dbb023a8fcb47221545d00b1a015c69 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Wed, 16 Nov 2022 13:10:55 -0500
Subject: [PATCH 034/120] add git blame ignore revs file

---
 .git-blame-ignore-revs | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .git-blame-ignore-revs

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..c5908fc89
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
+d110b04ad389134c82fa314e3aafc7b40043efb0

From 7a8e8e735d21bd9d992e291e6c39c922154168b5 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Wed, 16 Nov 2022 14:43:21 -0500
Subject: [PATCH 035/120] change click version in pre-commit

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e2055801b..5cb213327 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
     hooks:
       - id: black
         args: ["--line-length=88"]
-        additional_dependencies: ['click==8.0.1']
+        additional_dependencies: ['click==8.1.0']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8

From fca796cc2c9f4e86cad6ea0f5fe4305c071a2293 Mon Sep 17 00:00:00 2001
From: Daniil 
Date: Wed, 16 Nov 2022 17:55:53 -0500
Subject: [PATCH 036/120] Small code refactoring (#687)

---
 egs/librispeech/ASR/conformer_ctc2/train.py   |  15 ---
 .../ASR/conformer_ctc2/transformer.py         |  27 +---
 egs/librispeech/ASR/local/compile_hlg.py      |  25 ++--
 .../transducer_stateless/asr_datamodule.py    | 123 ++++++++++--------
 icefall/decode.py                             |  24 +++-
 icefall/utils.py                              |  12 +-
 6 files changed, 120 insertions(+), 106 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index 9d9c2af1f..18fa3e69f 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -166,13 +166,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
-    )
-
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -522,14 +515,6 @@ def compute_loss(
         nnet_output, encoder_memory, memory_mask = model(
             feature, supervisions, warmup=warmup
         )
-        # logging.info('feature shape: {}'.format(feature.shape))
-        # logging.info('nnet_output shape: {}'.format(nnet_output.shape))
-        # logging.info('encoder_memory shape: {}'.format(encoder_memory.shape))
-        # logging.info('memory_mask shape: {}'.format(memory_mask.shape))
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
 
     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py
index fa179acc0..3ef7edc23 100644
--- a/egs/librispeech/ASR/conformer_ctc2/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py
@@ -417,7 +417,6 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        activation: str = "relu",
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()
 
@@ -443,11 +442,6 @@ class TransformerEncoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerEncoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         src: torch.Tensor,
@@ -539,7 +533,6 @@ class TransformerDecoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        # activation: str = "relu",
         normalize_before: bool = True,
     ) -> None:
         super(TransformerDecoderLayer, self).__init__()
@@ -564,11 +557,6 @@ class TransformerDecoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerDecoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         tgt: torch.Tensor,
@@ -653,17 +641,6 @@ class TransformerDecoderLayer(nn.Module):
         return tgt
 
 
-def _get_activation_fn(activation: str):
-    if activation == "relu":
-        return nn.functional.relu
-    elif activation == "gelu":
-        return nn.functional.gelu
-
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
-
-
 class TransformerEncoder(nn.Module):
     r"""TransformerEncoder is a stack of N encoder layers
 
@@ -708,7 +685,7 @@ class TransformerEncoder(nn.Module):
         """
         output = src
 
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 src_mask=mask,
@@ -769,7 +746,7 @@ class TransformerDecoder(nn.Module):
         """
         output = tgt
 
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 memory,
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 9a35750e0..c628dfd53 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon
 
 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for LM used in HLG compiling.
+        """,
+    )
     parser.add_argument(
         "--lang-dir",
         type=str,
@@ -50,11 +57,13 @@ def get_args():
     return parser.parse_args()
 
 
-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     """
     Args:
       lang_dir:
         The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The stem name of the LM to use, e.g., G_3_gram.
 
     Return:
       An FSA representing HLG.
@@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
-    if Path("data/lm/G_3_gram.pt").is_file():
-        logging.info("Loading pre-compiled G_3_gram")
-        d = torch.load("data/lm/G_3_gram.pt")
+    if Path(f"data/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"data/lm/{lm}.pt")
         G = k2.Fsa.from_dict(d)
     else:
-        logging.info("Loading G_3_gram.fst.txt")
-        with open("data/lm/G_3_gram.fst.txt") as f:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"data/lm/{lm}.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+            torch.save(G.as_dict(), f"data/lm/{lm}.pt")
 
     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
@@ -144,7 +153,7 @@ def main():
 
     logging.info(f"Processing {lang_dir}")
 
-    HLG = compile_HLG(lang_dir)
+    HLG = compile_HLG(lang_dir, args.lm)
     logging.info(f"Saving HLG.pt to {lang_dir}")
     torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
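
A hedged sketch of the new `lm` parameter; the import path and the 4-gram file names are assumptions, and the same effect is obtained from the command line with `./local/compile_hlg.py --lm G_4_gram --lang-dir data/lang_bpe_500`.

```python
import torch

from compile_hlg import compile_HLG  # hypothetical import; normally run as a script

# Assumes data/lm/G_4_gram.fst.txt (or a cached data/lm/G_4_gram.pt) exists.
lang_dir = "data/lang_bpe_500"
HLG = compile_HLG(lang_dir, lm="G_4_gram")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
```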
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..94784c4c4 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -17,10 +17,11 @@
 
 
 import argparse
-import inspect
 import logging
+
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
@@ -28,7 +29,6 @@ from lhotse.dataset import (
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -140,7 +140,6 @@ class TedLiumAsrDataModule:
             "field: batch['supervisions']['cut'] with the cuts that "
             "were used to construct it.",
         )
-
         group.add_argument(
             "--num-workers",
             type=int,
@@ -148,14 +147,12 @@ class TedLiumAsrDataModule:
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
-
         group.add_argument(
             "--enable-spec-aug",
             type=str2bool,
             default=True,
             help="When enabled, use SpecAugment for training dataset.",
         )
-
         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -165,16 +162,48 @@ class TedLiumAsrDataModule:
             "Larger values mean more warping. "
             "A value less than 1 means to disable time warp.",
         )
-
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
             help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            "with training dataset.",
         )
 
-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
+
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=10,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                    max_frames_mask_fraction=0.15,
+                    p=0.9,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
         logging.info("About to get Musan cuts")
         transforms = []
         if self.args.enable_musan:
@@ -204,42 +233,7 @@ class TedLiumAsrDataModule:
                 )
             ] + transforms
 
-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                    max_frames_mask_fraction=0.15,
-                    p=0.9,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
-
         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
         if self.args.on_the_fly_feats:
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
@@ -259,6 +253,12 @@ class TedLiumAsrDataModule:
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
 
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -276,6 +276,11 @@ class TedLiumAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
             )
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         logging.info("About to create train dataloader")
         train_dl = DataLoader(
             train,
@@ -288,6 +293,7 @@ class TedLiumAsrDataModule:
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -310,11 +316,13 @@ class TedLiumAsrDataModule:
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
             )
+
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
@@ -326,25 +334,34 @@ class TedLiumAsrDataModule:
 
         return valid_dl
 
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
+
         logging.debug("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
+        if self.args.on_the_fly_feats:
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = K2SpeechRecognitionDataset(
+                return_cuts=self.args.return_cuts,
+            )
+
+        test_sampler = DynamicBucketingSampler(
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
-            sampler=sampler,
+            sampler=test_sampler,
             num_workers=self.args.num_workers,
+            persistent_workers=False,
         )
         return test_dl
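
A hedged sketch of the `sampler_state_dict` hook added to `train_dataloaders` above; `args`, `train_cuts`, and `checkpoint` are hypothetical placeholders, and storing the sampler state under the key "sampler" is an assumption.

```python
from asr_datamodule import TedLiumAsrDataModule  # as the recipe's train.py does

tedlium = TedLiumAsrDataModule(args)
sampler_state = checkpoint.get("sampler") if checkpoint is not None else None
# Resumes the DynamicBucketingSampler/SingleCutSampler where the previous run stopped.
train_dl = tedlium.train_dataloaders(train_cuts, sampler_state_dict=sampler_state)
```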
 
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..099e2d171 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -459,7 +459,8 @@ class Nbest(object):
 def one_best_decoding(
     lattice: k2.Fsa,
     use_double_scores: bool = True,
-) -> k2.Fsa:
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Get the best path from a lattice.
 
     Args:
@@ -468,11 +469,28 @@ def one_best_decoding(
       use_double_scores:
         True to use double precision floating point in the computation.
         False to use single precision.
+      lm_scale_list:
+        A list of floats representing LM score scales.
     Return:
       An FsaVec containing linear paths.
     """
-    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
-    return best_path
+
+    if lm_scale_list is not None:
+
+        ans = dict()
+        saved_am_scores = lattice.scores - lattice.lm_scores
+        for lm_scale in lm_scale_list:
+            am_scores = saved_am_scores / lm_scale
+            lattice.scores = am_scores + lattice.lm_scores
+
+            best_path = k2.shortest_path(
+                lattice, use_double_scores=use_double_scores
+            )
+            key = f"lm_scale_{lm_scale}"
+            ans[key] = best_path
+        return ans
+
+    return k2.shortest_path(lattice, use_double_scores=use_double_scores)
 
 
 def nbest_decoding(
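
A hedged sketch of the extended `one_best_decoding` interface shown above; `lattice` is assumed to be an FsaVec carrying both `scores` and `lm_scores`.

```python
from icefall.decode import one_best_decoding

# With lm_scale_list, the return value is a dict keyed by "lm_scale_<x>".
paths = one_best_decoding(
    lattice,
    use_double_scores=True,
    lm_scale_list=[0.1, 0.3, 0.5],
)
best_at_0_3 = paths["lm_scale_0.3"]

# Without the list, the behaviour is unchanged: a single FsaVec of best paths.
best_path = one_best_decoding(lattice)
```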
diff --git a/icefall/utils.py b/icefall/utils.py
index c502cb4d8..143c79497 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -194,8 +194,16 @@ def encode_supervisions(
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
-            supervisions["start_frame"] // subsampling_factor,
-            supervisions["num_frames"] // subsampling_factor,
+            torch.div(
+                supervisions["start_frame"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                subsampling_factor,
+                rounding_mode="floor",
+            )
         ),
         1,
     ).to(torch.int32)
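
A hedged illustration of why the hunk above swaps `//` for `torch.div(..., rounding_mode="floor")`: the result is identical, but the explicit rounding mode avoids the floor-division deprecation warning that some PyTorch 1.x releases emit for integer tensors.

```python
import torch

supervisions = {"start_frame": torch.tensor([0, 8, 17])}
subsampling_factor = 4

old = supervisions["start_frame"] // subsampling_factor
new = torch.div(
    supervisions["start_frame"], subsampling_factor, rounding_mode="floor"
)
assert torch.equal(old, new)  # both are tensor([0, 2, 4])
```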

From 60317120caef28a43ed904a26263c57536ca95ab Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Thu, 17 Nov 2022 20:19:32 +0800
Subject: [PATCH 037/120] Revert "Apply new Black style changes"

---
 .git-blame-ignore-revs                        |    2 -
 .github/workflows/style_check.yml             |   11 +-
 .pre-commit-config.yaml                       |   28 +-
 docker/README.md                              |   24 +-
 .../Dockerfile                                |   14 +-
 .../Dockerfile                                |   17 +-
 .../images/k2-gt-v1.9-blueviolet.svg          |    2 +-
 .../images/python-gt-v3.6-blue.svg            |    2 +-
 .../images/torch-gt-v1.6.0-green.svg          |    2 +-
 docs/source/recipes/aishell/index.rst         |    1 +
 docs/source/recipes/timit/index.rst           |    1 +
 docs/source/recipes/timit/tdnn_ligru_ctc.rst  |   28 +-
 docs/source/recipes/timit/tdnn_lstm_ctc.rst   |   24 +-
 .../local/compute_fbank_aidatatang_200zh.py   |    8 +-
 .../ASR/local/prepare_char.py                 |    8 +-
 .../ASR/local/prepare_lang.py                 |    4 +-
 .../ASR/local/test_prepare_lang.py            |    4 +-
 egs/aidatatang_200zh/ASR/local/text2token.py  |   21 +-
 egs/aidatatang_200zh/ASR/prepare.sh           |    3 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless2/decode.py    |   50 +-
 .../pruned_transducer_stateless2/export.py    |   20 +-
 .../pretrained.py                             |   41 +-
 .../ASR/pruned_transducer_stateless2/train.py |   50 +-
 egs/aishell/ASR/conformer_ctc/conformer.py    |   70 +-
 egs/aishell/ASR/conformer_ctc/decode.py       |   29 +-
 egs/aishell/ASR/conformer_ctc/export.py       |   17 +-
 egs/aishell/ASR/conformer_ctc/pretrained.py   |   39 +-
 egs/aishell/ASR/conformer_ctc/subsampling.py  |   16 +-
 .../ASR/conformer_ctc/test_subsampling.py     |    3 +-
 egs/aishell/ASR/conformer_ctc/train.py        |   12 +-
 egs/aishell/ASR/conformer_ctc/transformer.py  |   44 +-
 egs/aishell/ASR/conformer_mmi/conformer.py    |   70 +-
 egs/aishell/ASR/conformer_mmi/decode.py       |   33 +-
 egs/aishell/ASR/conformer_mmi/subsampling.py  |   16 +-
 egs/aishell/ASR/conformer_mmi/train.py        |    8 +-
 egs/aishell/ASR/conformer_mmi/transformer.py  |   44 +-
 .../local/compute_fbank_aidatatang_200zh.py   |    8 +-
 .../ASR/local/compute_fbank_aishell.py        |    8 +-
 egs/aishell/ASR/local/prepare_char.py         |    8 +-
 egs/aishell/ASR/local/prepare_lang.py         |    4 +-
 egs/aishell/ASR/local/test_prepare_lang.py    |    4 +-
 .../pruned_transducer_stateless2/decode.py    |   50 +-
 .../pruned_transducer_stateless2/export.py    |   31 +-
 .../pretrained.py                             |   50 +-
 .../ASR/pruned_transducer_stateless2/train.py |   64 +-
 .../pruned_transducer_stateless3/decode.py    |   73 +-
 .../pruned_transducer_stateless3/export.py    |   54 +-
 .../ASR/pruned_transducer_stateless3/model.py |    8 +-
 .../pretrained.py                             |   50 +-
 .../ASR/pruned_transducer_stateless3/train.py |   79 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  118 +-
 egs/aishell/ASR/tdnn_lstm_ctc/decode.py       |   33 +-
 egs/aishell/ASR/tdnn_lstm_ctc/model.py        |    5 +-
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   |   37 +-
 egs/aishell/ASR/tdnn_lstm_ctc/train.py        |    7 +-
 .../ASR/transducer_stateless/beam_search.py   |   22 +-
 .../ASR/transducer_stateless/conformer.py     |   70 +-
 .../ASR/transducer_stateless/decode.py        |   39 +-
 .../ASR/transducer_stateless/decoder.py       |    4 +-
 .../ASR/transducer_stateless/export.py        |   20 +-
 egs/aishell/ASR/transducer_stateless/model.py |    4 +-
 .../ASR/transducer_stateless/pretrained.py    |   36 +-
 egs/aishell/ASR/transducer_stateless/train.py |   15 +-
 .../ASR/transducer_stateless/transformer.py   |    4 +-
 .../asr_datamodule.py                         |   85 +-
 .../transducer_stateless_modified-2/decode.py |   46 +-
 .../transducer_stateless_modified-2/export.py |   20 +-
 .../pretrained.py                             |   50 +-
 .../transducer_stateless_modified-2/train.py  |   22 +-
 .../transducer_stateless_modified/decode.py   |   46 +-
 .../transducer_stateless_modified/export.py   |   20 +-
 .../pretrained.py                             |   50 +-
 .../transducer_stateless_modified/train.py    |   15 +-
 egs/aishell2/ASR/local/__init__.py            |    0
 .../ASR/local/compute_fbank_aishell2.py       |    8 +-
 .../pruned_transducer_stateless5/__init__.py  |    0
 .../asr_datamodule.py                         |  114 +-
 .../pruned_transducer_stateless5/decode.py    |   67 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   40 +-
 .../ASR/pruned_transducer_stateless5/train.py |   67 +-
 .../ASR/local/compute_fbank_aishell4.py       |    8 +-
 egs/aishell4/ASR/local/prepare_char.py        |    8 +-
 egs/aishell4/ASR/local/prepare_lang.py        |    4 +-
 egs/aishell4/ASR/local/test_prepare_lang.py   |    4 +-
 egs/aishell4/ASR/local/text2token.py          |   21 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless5/decode.py    |   69 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   45 +-
 .../ASR/pruned_transducer_stateless5/train.py |   59 +-
 .../ASR/local/compute_fbank_alimeeting.py     |    8 +-
 egs/alimeeting/ASR/local/prepare_char.py      |    8 +-
 egs/alimeeting/ASR/local/prepare_lang.py      |    4 +-
 egs/alimeeting/ASR/local/test_prepare_lang.py |    4 +-
 egs/alimeeting/ASR/local/text2segments.py     |    2 +-
 egs/alimeeting/ASR/local/text2token.py        |   21 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless2/decode.py    |   60 +-
 .../pruned_transducer_stateless2/export.py    |   20 +-
 .../pretrained.py                             |   41 +-
 .../ASR/pruned_transducer_stateless2/train.py |   50 +-
 egs/csj/ASR/.gitignore                        |    2 +-
 egs/csj/ASR/local/compute_fbank_csj.py        |   38 +-
 egs/csj/ASR/local/compute_fbank_musan.py      |   17 +-
 egs/csj/ASR/local/conf/disfluent.ini          |   55 +-
 egs/csj/ASR/local/conf/fluent.ini             |   55 +-
 egs/csj/ASR/local/conf/number.ini             |   55 +-
 egs/csj/ASR/local/conf/symbol.ini             |   55 +-
 .../ASR/local/display_manifest_statistics.py  |    4 +-
 egs/csj/ASR/local/prepare_lang_char.py        |   17 +-
 egs/csj/ASR/local/validate_manifest.py        |    7 +-
 .../ASR/conformer_ctc/asr_datamodule.py       |  117 +-
 egs/gigaspeech/ASR/conformer_ctc/conformer.py |   66 +-
 egs/gigaspeech/ASR/conformer_ctc/decode.py    |   29 +-
 .../ASR/conformer_ctc/gigaspeech_scoring.py   |    3 +-
 .../ASR/conformer_ctc/label_smoothing.py      |    7 +-
 .../ASR/conformer_ctc/subsampling.py          |   16 +-
 egs/gigaspeech/ASR/conformer_ctc/train.py     |   12 +-
 .../ASR/conformer_ctc/transformer.py          |   49 +-
 .../compute_fbank_gigaspeech_dev_test.py      |    4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   10 +-
 .../ASR/local/preprocess_gigaspeech.py        |   10 +-
 .../asr_datamodule.py                         |  117 +-
 .../pruned_transducer_stateless2/decode.py    |   42 +-
 .../pruned_transducer_stateless2/export.py    |   24 +-
 .../ASR/pruned_transducer_stateless2/train.py |   48 +-
 egs/librispeech/ASR/conformer_ctc/ali.py      |   25 +-
 .../ASR/conformer_ctc/conformer.py            |   66 +-
 egs/librispeech/ASR/conformer_ctc/decode.py   |   29 +-
 egs/librispeech/ASR/conformer_ctc/export.py   |   17 +-
 .../ASR/conformer_ctc/label_smoothing.py      |    7 +-
 .../ASR/conformer_ctc/pretrained.py           |   33 +-
 .../ASR/conformer_ctc/subsampling.py          |   16 +-
 egs/librispeech/ASR/conformer_ctc/train.py    |   22 +-
 .../ASR/conformer_ctc/transformer.py          |   49 +-
 .../ASR/conformer_ctc2/attention.py           |   19 +-
 .../ASR/conformer_ctc2/conformer.py           |   65 +-
 egs/librispeech/ASR/conformer_ctc2/decode.py  |   56 +-
 egs/librispeech/ASR/conformer_ctc2/export.py  |   49 +-
 egs/librispeech/ASR/conformer_ctc2/train.py   |   39 +-
 .../ASR/conformer_ctc2/transformer.py         |   46 +-
 .../ASR/conformer_mmi/conformer.py            |   70 +-
 egs/librispeech/ASR/conformer_mmi/decode.py   |   29 +-
 .../ASR/conformer_mmi/subsampling.py          |   16 +-
 .../ASR/conformer_mmi/test_subsampling.py     |    3 +-
 .../ASR/conformer_mmi/test_transformer.py     |    9 +-
 .../ASR/conformer_mmi/train-with-attention.py |   27 +-
 egs/librispeech/ASR/conformer_mmi/train.py    |   27 +-
 .../ASR/conformer_mmi/transformer.py          |   28 +-
 .../decode.py                                 |   69 +-
 .../emformer.py                               |  119 +-
 .../export.py                                 |   47 +-
 .../stream.py                                 |    8 +-
 .../streaming_decode.py                       |   75 +-
 .../train.py                                  |   56 +-
 .../decode.py                                 |   69 +-
 .../emformer.py                               |  108 +-
 .../export.py                                 |   47 +-
 .../streaming_decode.py                       |   75 +-
 .../train.py                                  |   56 +-
 .../ASR/local/add_alignment_librispeech.py    |   12 +-
 egs/librispeech/ASR/local/compile_hlg.py      |    6 +-
 egs/librispeech/ASR/local/compile_lg.py       |    4 +-
 .../compute_fbank_gigaspeech_dev_test.py      |    4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   10 +-
 .../ASR/local/compute_fbank_librispeech.py    |    8 +-
 .../ASR/local/compute_fbank_musan.py          |    8 +-
 .../convert_transcript_words_to_tokens.py     |   16 +-
 egs/librispeech/ASR/local/download_lm.py      |    4 +-
 egs/librispeech/ASR/local/filter_cuts.py      |   10 +-
 .../ASR/local/generate_unique_lexicon.py      |    4 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |    4 +-
 .../ASR/local/prepare_lm_training_data.py     |   11 +-
 .../ASR/local/preprocess_gigaspeech.py        |    4 +-
 .../ASR/local/test_prepare_lang.py            |    4 +-
 .../ASR/local/validate_manifest.py            |    7 +-
 .../ASR/lstm_transducer_stateless/decode.py   |  818 ++++++++++++
 .../ASR/lstm_transducer_stateless/export.py   |  388 ++++++
 .../jit_pretrained.py                         |  322 +++++
 .../ASR/lstm_transducer_stateless/lstm.py     |  871 +++++++++++++
 .../ASR/lstm_transducer_stateless/model.py    |  210 +++
 .../lstm_transducer_stateless/pretrained.py   |  352 +++++
 .../ASR/lstm_transducer_stateless/stream.py   |  148 +++
 .../streaming_decode.py                       |  968 ++++++++++++++
 .../ASR/lstm_transducer_stateless/train.py    | 1157 +++++++++++++++++
 .../ASR/lstm_transducer_stateless2/decode.py  |   67 +-
 .../ASR/lstm_transducer_stateless2/export.py  |   59 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/lstm_transducer_stateless2/model.py   |    8 +-
 .../lstm_transducer_stateless2/ncnn-decode.py |   15 +-
 .../lstm_transducer_stateless2/pretrained.py  |   40 +-
 .../streaming-ncnn-decode.py                  |   27 +-
 .../streaming-onnx-decode.py                  |   45 +-
 .../ASR/lstm_transducer_stateless2/train.py   |   68 +-
 .../ASR/lstm_transducer_stateless3/decode.py  |   79 +-
 .../ASR/lstm_transducer_stateless3/export.py  |   47 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/lstm_transducer_stateless3/lstm.py    |   14 +-
 .../lstm_transducer_stateless3/pretrained.py  |   40 +-
 .../streaming_decode.py                       |   74 +-
 .../ASR/lstm_transducer_stateless3/train.py   |   66 +-
 .../ASR/pruned2_knowledge/asr_datamodule.py   |  125 +-
 .../ASR/pruned2_knowledge/beam_search.py      |   18 +-
 .../ASR/pruned2_knowledge/conformer.py        |   90 +-
 .../ASR/pruned2_knowledge/decode.py           |   44 +-
 .../ASR/pruned2_knowledge/decoder.py          |    4 +-
 .../ASR/pruned2_knowledge/decoder2.py         |   84 +-
 .../ASR/pruned2_knowledge/export.py           |   20 +-
 .../ASR/pruned2_knowledge/joiner.py           |    4 +-
 .../ASR/pruned2_knowledge/model.py            |    8 +-
 .../ASR/pruned2_knowledge/optim.py            |   35 +-
 .../ASR/pruned2_knowledge/sampling.py         |  180 ++-
 .../ASR/pruned2_knowledge/scaling.py          |   51 +-
 .../ASR/pruned2_knowledge/scaling_tmp.py      |  355 ++---
 .../ASR/pruned2_knowledge/train.py            |   50 +-
 .../pruned_stateless_emformer_rnnt2/decode.py |   69 +-
 .../emformer.py                               |    8 +-
 .../pruned_stateless_emformer_rnnt2/export.py |   47 +-
 .../pruned_stateless_emformer_rnnt2/model.py  |    4 +-
 .../pruned_stateless_emformer_rnnt2/train.py  |   44 +-
 .../beam_search.py                            |   26 +-
 .../ASR/pruned_transducer_stateless/decode.py |   44 +-
 .../decode_stream.py                          |   19 +-
 .../pruned_transducer_stateless/decoder.py    |    4 +-
 .../ASR/pruned_transducer_stateless/export.py |   20 +-
 .../ASR/pruned_transducer_stateless/model.py  |    4 +-
 .../pruned_transducer_stateless/pretrained.py |   36 +-
 .../streaming_beam_search.py                  |    8 +-
 .../streaming_decode.py                       |   39 +-
 .../ASR/pruned_transducer_stateless/train.py  |   46 +-
 .../beam_search.py                            |   51 +-
 .../pruned_transducer_stateless2/conformer.py |   97 +-
 .../pruned_transducer_stateless2/decode.py    |   50 +-
 .../pruned_transducer_stateless2/decoder.py   |    8 +-
 .../pruned_transducer_stateless2/export.py    |   24 +-
 .../pruned_transducer_stateless2/joiner.py    |    4 +-
 .../ASR/pruned_transducer_stateless2/model.py |    8 +-
 .../ASR/pruned_transducer_stateless2/optim.py |   35 +-
 .../pretrained.py                             |   36 +-
 .../pruned_transducer_stateless2/scaling.py   |   56 +-
 .../streaming_beam_search.py                  |   12 +-
 .../streaming_decode.py                       |   39 +-
 .../ASR/pruned_transducer_stateless2/train.py |   58 +-
 .../asr_datamodule.py                         |   85 +-
 .../decode-giga.py                            |   54 +-
 .../pruned_transducer_stateless3/decode.py    |   74 +-
 .../pruned_transducer_stateless3/export.py    |   32 +-
 .../gigaspeech.py                             |    8 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/pruned_transducer_stateless3/model.py |    8 +-
 .../onnx_check.py                             |   24 +-
 .../onnx_pretrained.py                        |   27 +-
 .../pretrained.py                             |   36 +-
 .../scaling_converter.py                      |   10 +-
 .../streaming_decode.py                       |   39 +-
 .../pruned_transducer_stateless3/test_onnx.py |   24 +-
 .../ASR/pruned_transducer_stateless3/train.py |   65 +-
 .../pruned_transducer_stateless4/decode.py    |   79 +-
 .../pruned_transducer_stateless4/export.py    |   47 +-
 .../streaming_decode.py                       |   62 +-
 .../ASR/pruned_transducer_stateless4/train.py |   61 +-
 .../pruned_transducer_stateless5/conformer.py |  118 +-
 .../pruned_transducer_stateless5/decode.py    |   67 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   40 +-
 .../streaming_decode.py                       |   62 +-
 .../ASR/pruned_transducer_stateless5/train.py |   66 +-
 .../pruned_transducer_stateless6/conformer.py |   67 +-
 .../pruned_transducer_stateless6/decode.py    |   69 +-
 .../pruned_transducer_stateless6/export.py    |   24 +-
 .../extract_codebook_index.py                 |    3 +-
 .../hubert_decode.py                          |   17 +-
 .../hubert_xlarge.py                          |   22 +-
 .../ASR/pruned_transducer_stateless6/model.py |   12 +-
 .../ASR/pruned_transducer_stateless6/train.py |   65 +-
 .../pruned_transducer_stateless6/vq_utils.py  |   31 +-
 .../pruned_transducer_stateless7/decode.py    |   67 +-
 .../pruned_transducer_stateless7/decoder.py   |    6 +-
 .../pruned_transducer_stateless7/export.py    |   47 +-
 .../jit_pretrained.py                         |   21 +-
 .../pruned_transducer_stateless7/joiner.py    |    4 +-
 .../ASR/pruned_transducer_stateless7/model.py |   16 +-
 .../ASR/pruned_transducer_stateless7/optim.py |  435 +++----
 .../pretrained.py                             |   40 +-
 .../pruned_transducer_stateless7/scaling.py   |  487 ++++---
 .../scaling_converter.py                      |   12 +-
 .../ASR/pruned_transducer_stateless7/train.py |   88 +-
 .../pruned_transducer_stateless7/zipformer.py |  654 +++++-----
 .../pruned_transducer_stateless8/decode.py    |   67 +-
 .../pruned_transducer_stateless8/export.py    |   47 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/pruned_transducer_stateless8/model.py |    4 +-
 .../pretrained.py                             |   40 +-
 .../ASR/pruned_transducer_stateless8/train.py |   99 +-
 .../ASR/streaming_conformer_ctc/README.md     |   16 +-
 .../ASR/streaming_conformer_ctc/conformer.py  |  116 +-
 .../streaming_decode.py                       |   68 +-
 .../ASR/streaming_conformer_ctc/train.py      |   16 +-
 .../streaming_conformer_ctc/transformer.py    |   40 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  113 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |   29 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/model.py    |    5 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |   43 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py    |    8 +-
 egs/librispeech/ASR/transducer/beam_search.py |   14 +-
 egs/librispeech/ASR/transducer/decode.py      |   28 +-
 egs/librispeech/ASR/transducer/export.py      |   17 +-
 egs/librispeech/ASR/transducer/pretrained.py  |   33 +-
 egs/librispeech/ASR/transducer/rnn.py         |   24 +-
 egs/librispeech/ASR/transducer/test_rnn.py    |   16 +-
 egs/librispeech/ASR/transducer/train.py       |   12 +-
 .../ASR/transducer_lstm/beam_search.py        |   14 +-
 egs/librispeech/ASR/transducer_lstm/decode.py |   28 +-
 .../ASR/transducer_lstm/encoder.py            |    4 +-
 egs/librispeech/ASR/transducer_lstm/train.py  |   12 +-
 .../ASR/transducer_stateless/alignment.py     |    4 +-
 .../ASR/transducer_stateless/beam_search.py   |   28 +-
 .../ASR/transducer_stateless/compute_ali.py   |   24 +-
 .../ASR/transducer_stateless/conformer.py     |  107 +-
 .../ASR/transducer_stateless/decode.py        |   42 +-
 .../ASR/transducer_stateless/decoder.py       |    4 +-
 .../ASR/transducer_stateless/export.py        |   20 +-
 .../ASR/transducer_stateless/joiner.py        |    8 +-
 .../ASR/transducer_stateless/pretrained.py    |   36 +-
 .../transducer_stateless/test_compute_ali.py  |   11 +-
 .../transducer_stateless/test_conformer.py    |    4 +-
 .../ASR/transducer_stateless/train.py         |   23 +-
 .../ASR/transducer_stateless/transformer.py   |    4 +-
 .../ASR/transducer_stateless2/decode.py       |   42 +-
 .../ASR/transducer_stateless2/export.py       |   20 +-
 .../ASR/transducer_stateless2/pretrained.py   |   36 +-
 .../ASR/transducer_stateless2/train.py        |   23 +-
 .../decode.py                                 |   42 +-
 .../export.py                                 |   20 +-
 .../pretrained.py                             |   36 +-
 .../test_asr_datamodule.py                    |    4 +-
 .../train.py                                  |   22 +-
 egs/ptb/LM/local/sort_lm_training_data.py     |    4 +-
 .../LM/local/test_prepare_lm_training_data.py |    4 +-
 .../ASR/local/compute_fbank_musan.py          |    8 +-
 .../ASR/local/compute_fbank_spgispeech.py     |   14 +-
 egs/spgispeech/ASR/local/prepare_splits.py    |    8 +-
 .../asr_datamodule.py                         |  100 +-
 .../pruned_transducer_stateless2/decode.py    |   66 +-
 .../pruned_transducer_stateless2/export.py    |   26 +-
 .../ASR/pruned_transducer_stateless2/train.py |   51 +-
 .../ASR/local/compute_fbank_tal_csasr.py      |    8 +-
 egs/tal_csasr/ASR/local/prepare_char.py       |    4 +-
 egs/tal_csasr/ASR/local/prepare_lang.py       |    4 +-
 egs/tal_csasr/ASR/local/test_prepare_lang.py  |    4 +-
 egs/tal_csasr/ASR/local/text2token.py         |   21 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless5/decode.py    |   77 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   40 +-
 .../ASR/pruned_transducer_stateless5/train.py |   59 +-
 .../ASR/local/compute_fbank_tedlium.py        |    8 +-
 .../convert_transcript_words_to_bpe_ids.py    |    4 +-
 egs/tedlium3/ASR/local/prepare_lexicon.py     |   11 +-
 egs/tedlium3/ASR/local/prepare_transcripts.py |   11 +-
 .../ASR/pruned_transducer_stateless/decode.py |   38 +-
 .../ASR/pruned_transducer_stateless/export.py |   20 +-
 .../pruned_transducer_stateless/pretrained.py |   41 +-
 .../ASR/pruned_transducer_stateless/train.py  |   35 +-
 .../transducer_stateless/asr_datamodule.py    |  127 +-
 .../ASR/transducer_stateless/beam_search.py   |   30 +-
 .../ASR/transducer_stateless/decode.py        |   31 +-
 .../ASR/transducer_stateless/decoder.py       |    4 +-
 .../ASR/transducer_stateless/export.py        |   20 +-
 .../ASR/transducer_stateless/pretrained.py    |   36 +-
 .../ASR/transducer_stateless/train.py         |   11 +-
 egs/timit/ASR/RESULTS.md                      |    2 +-
 egs/timit/ASR/local/compile_hlg.py            |    4 +-
 egs/timit/ASR/local/compute_fbank_timit.py    |    8 +-
 egs/timit/ASR/local/prepare_lexicon.py        |    8 +-
 egs/timit/ASR/prepare.sh                      |    4 +-
 egs/timit/ASR/tdnn_ligru_ctc/decode.py        |   29 +-
 egs/timit/ASR/tdnn_ligru_ctc/model.py         |   12 +-
 egs/timit/ASR/tdnn_ligru_ctc/pretrained.py    |   43 +-
 egs/timit/ASR/tdnn_ligru_ctc/train.py         |    4 +-
 egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py |  104 +-
 egs/timit/ASR/tdnn_lstm_ctc/decode.py         |   29 +-
 egs/timit/ASR/tdnn_lstm_ctc/model.py          |    5 +-
 egs/timit/ASR/tdnn_lstm_ctc/pretrained.py     |   43 +-
 egs/timit/ASR/tdnn_lstm_ctc/train.py          |    4 +-
 .../compute_fbank_wenetspeech_dev_test.py     |   11 +-
 .../local/compute_fbank_wenetspeech_splits.py |   10 +-
 egs/wenetspeech/ASR/local/prepare_char.py     |    8 +-
 .../ASR/local/preprocess_wenetspeech.py       |    6 +-
 egs/wenetspeech/ASR/local/text2token.py       |   21 +-
 egs/wenetspeech/ASR/prepare.sh                |    2 +-
 .../asr_datamodule.py                         |  121 +-
 .../pruned_transducer_stateless2/decode.py    |   64 +-
 .../pruned_transducer_stateless2/export.py    |   28 +-
 .../jit_pretrained.py                         |   21 +-
 .../onnx_check.py                             |   24 +-
 .../onnx_pretrained.py                        |   27 +-
 .../pretrained.py                             |   41 +-
 .../ASR/pruned_transducer_stateless2/train.py |   50 +-
 .../pruned_transducer_stateless5/conformer.py |   97 +-
 .../pruned_transducer_stateless5/decode.py    |   75 +-
 .../decode_stream.py                          |   19 +-
 .../pruned_transducer_stateless5/export.py    |   20 +-
 .../pretrained.py                             |   41 +-
 .../streaming_beam_search.py                  |    8 +-
 .../streaming_decode.py                       |   62 +-
 .../ASR/pruned_transducer_stateless5/train.py |   67 +-
 egs/yesno/ASR/local/compile_hlg.py            |    4 +-
 egs/yesno/ASR/local/compute_fbank_yesno.py    |   12 +-
 egs/yesno/ASR/tdnn/asr_datamodule.py          |   74 +-
 egs/yesno/ASR/tdnn/decode.py                  |   29 +-
 egs/yesno/ASR/tdnn/pretrained.py              |   37 +-
 egs/yesno/ASR/tdnn/train.py                   |    4 +-
 egs/yesno/ASR/transducer/decode.py            |   25 +-
 egs/yesno/ASR/transducer/train.py             |    4 +-
 icefall/char_graph_compiler.py                |    8 +-
 icefall/checkpoint.py                         |   12 +-
 icefall/decode.py                             |   40 +-
 icefall/diagnostics.py                        |   80 +-
 icefall/dist.py                               |    4 +-
 icefall/env.py                                |    4 +-
 icefall/graph_compiler.py                     |    4 +-
 icefall/hooks.py                              |   19 +-
 icefall/lexicon.py                            |   16 +-
 icefall/mmi.py                                |   29 +-
 icefall/mmi_graph_compiler.py                 |    8 +-
 icefall/rnn_lm/compute_perplexity.py          |   15 +-
 icefall/rnn_lm/dataset.py                     |    8 +-
 icefall/rnn_lm/export.py                      |   17 +-
 icefall/rnn_lm/model.py                       |   28 +-
 icefall/rnn_lm/train.py                       |   11 +-
 icefall/shared/make_kn_lm.py                  |  184 +--
 icefall/utils.py                              |   66 +-
 pyproject.toml                                |    2 +-
 setup.py                                      |    3 +-
 test/test_checkpoint.py                       |    6 +-
 test/test_decode.py                           |    1 -
 test/test_graph_compiler.py                   |    4 +-
 test/test_utils.py                            |    4 +-
 441 files changed, 14535 insertions(+), 6789 deletions(-)
 delete mode 100644 .git-blame-ignore-revs
 mode change 100644 => 100755 egs/aishell2/ASR/local/__init__.py
 mode change 100644 => 100755 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
 mode change 100644 => 100755 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/decode.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/export.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/train.py

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
deleted file mode 100644
index c5908fc89..000000000
--- a/.git-blame-ignore-revs
+++ /dev/null
@@ -1,2 +0,0 @@
-# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
-d110b04ad389134c82fa314e3aafc7b40043efb0
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 45d261ccc..90459bc1c 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -45,18 +45,17 @@ jobs:
 
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
-          # Click issue fixed in https://github.com/psf/black/pull/2966
+          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+          # See https://github.com/psf/black/issues/2964
+          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
 
       - name: Run flake8
         shell: bash
         working-directory: ${{github.workspace}}
         run: |
           # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
-            --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
+          flake8 . --count --show-source --statistics
+          flake8 .
 
       - name: Run black
         shell: bash
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5cb213327..446ba0fe7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,38 +1,26 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 21.6b0
     hooks:
       - id: black
-        args: ["--line-length=88"]
-        additional_dependencies: ['click==8.1.0']
+        args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 3.9.2
     hooks:
       - id: flake8
-        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
-
-      # What are we ignoring here?
-      # E203: whitespace before ':'
-      # E266: too many leading '#' for block comment
-      # E501: line too long
-      # F401: module imported but unused
-      # E402: module level import not at top of file
-      # F403: 'from module import *' used; unable to detect undefined names
-      # F841: local variable is assigned to but never used
-      # W503: line break before binary operator
-      # In addition, the default ignore list is:
-      # E121,E123,E126,E226,E24,E704,W503,W504
+        args: [--max-line-length=80]
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.9.2
     hooks:
       - id: isort
-        args: ["--profile=black"]
+        args: [--profile=black, --line-length=80]
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.2.0
+    rev: v4.0.1
     hooks:
       - id: check-executables-have-shebangs
       - id: end-of-file-fixer
diff --git a/docker/README.md b/docker/README.md
index c14b9bf75..6f2314e96 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -2,7 +2,7 @@
 
 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
 
-If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
+If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. 
 
 Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVIDIA driver supports at least CUDA 11.0.
 
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
 
 ```bash
 $ nvidia-smi
-Tue Sep 20 00:26:13 2022
+Tue Sep 20 00:26:13 2022       
 +-----------------------------------------------------------------------------+
 | NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
 |-------------------------------+----------------------+----------------------+
@@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
 | 41%   30C    P8    11W / 280W |      6MiB / 24220MiB |      0%      Default |
 |                               |                      |                  N/A |
 +-------------------------------+----------------------+----------------------+
-
+                                                                               
 +-----------------------------------------------------------------------------+
 | Processes:                                                                  |
 |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
@@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
 ```
 
 ## Building images locally
-If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly.
-For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
+If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. 
+For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. 
 
 ```dockerfile
 ENV http_proxy=http://aaa.bb.cc.net:8080 \
     https_proxy=http://aaa.bb.cc.net:8080
 ```
 
-Then, proceed with these commands.
+Then, proceed with these commands. 
 
 ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
 
@@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
 ```
 
 ### Tips:
-1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
+1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 
 
 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
 
-Overall, your docker run command should look like this.
+Overall, your docker run command should look like this. 
 
 ```bash
 docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
@@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
 
 ### Linking to icefall in your host machine
 
-If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
+If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. 
 
-Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
+Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. 
 Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
 
 Use these commands once you are inside the container.
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
 docker exec -it icefall /bin/bash
 ```
 
-## Restarting a killed container that has been run before.
+## Restarting a killed container that has been run before. 
 ```bash
 docker start -ai icefall
 ```
@@ -111,4 +111,4 @@ docker start -ai icefall
 ## Sample usage of the CPU based images:
 ```bash
 docker run -it icefall /bin/bash
-```
+``` 
diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
index ff9e40604..3637d2f11 100644
--- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
@@ -1,7 +1,7 @@
 FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080
+#	https_proxy=http://aaa.bbb.cc.net:8080 
 
 # install normal source
 RUN apt-get update && \
@@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-
-# flac
+	
+# flac 
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \
+    cd /opt && \ 
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd -
+    cd - 
 
 RUN conda install -y -c pytorch torchaudio=0.12 && \
     pip install graphviz
-
+	
 
 #install k2 from source
 RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
 	pip install -r requirements.txt
 
-RUN pip install kaldifeat
+RUN pip install kaldifeat 
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
index 5c7423fa5..17a8215f9 100644
--- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
@@ -1,12 +1,12 @@
 FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080
+#	https_proxy=http://aaa.bbb.cc.net:8080 
 
 RUN rm /etc/apt/sources.list.d/cuda.list && \
 	rm /etc/apt/sources.list.d/nvidia-ml.list && \
 	apt-key del 7fa2af80
-
+	
 # install normal source
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
@@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
 	curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
 	echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
 	echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
-	rm -rf /var/lib/apt/lists/* && \
+	rm -rf /var/lib/apt/lists/* && \ 
 	mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
     mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
     mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-
-# flac
+	
+# flac 
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \
+    cd /opt && \ 
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd -
+    cd - 
 
 RUN conda install -y -c pytorch torchaudio=0.7.1 && \
     pip install graphviz
@@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
     cd -
 
 # install  lhotse
-RUN pip install git+https://github.com/lhotse-speech/lhotse
+RUN pip install git+https://github.com/lhotse-speech/lhotse 
 
 RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
@@ -88,3 +88,4 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
+
diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
index 3019ff03d..534b2e534 100644
--- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
+++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
@@ -1 +1 @@
-k2: >= v1.9k2>= v1.9
+k2: >= v1.9k2>= v1.9
\ No newline at end of file
diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg
index df677ad09..4254dc58a 100644
--- a/docs/source/installation/images/python-gt-v3.6-blue.svg
+++ b/docs/source/installation/images/python-gt-v3.6-blue.svg
@@ -1 +1 @@
-python: >= 3.6python>= 3.6
+python: >= 3.6python>= 3.6
\ No newline at end of file
diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
index d7007d742..d3ece9a17 100644
--- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg
+++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
@@ -1 +1 @@
-torch: >= 1.6.0torch>= 1.6.0
+torch: >= 1.6.0torch>= 1.6.0
\ No newline at end of file
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst
index b77d59bca..d072d6e9c 100644
--- a/docs/source/recipes/aishell/index.rst
+++ b/docs/source/recipes/aishell/index.rst
@@ -19,3 +19,4 @@ It can be downloaded from ``_
    tdnn_lstm_ctc
    conformer_ctc
    stateless_transducer
+
diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst
index 5ee147be7..17f40cdb7 100644
--- a/docs/source/recipes/timit/index.rst
+++ b/docs/source/recipes/timit/index.rst
@@ -6,3 +6,4 @@ TIMIT
 
    tdnn_ligru_ctc
    tdnn_lstm_ctc
+
diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
index 3d7aefe02..186420ee7 100644
--- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
@@ -148,10 +148,10 @@ Some commonly used options are:
 
         $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
 
-    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
-    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
-    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
-    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
+    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, 
+    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, 
+    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, 
+    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, 
     ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_ligru_ctc/pretrained.py
+  ./tdnn_ligru_ctc/pretrained.py 
     --method 1best
-    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
-    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
-    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
+    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt 
+    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt 
+    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -337,7 +337,7 @@ The output is:
   2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 20:41:39,829 INFO [pretrained.py:267]
+  2021-11-08 20:41:39,829 INFO [pretrained.py:267] 
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
@@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
     --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.1 \
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -378,7 +378,7 @@ The decoding output is:
   2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:37:56,348 INFO [pretrained.py:267]
+  2021-11-08 20:37:56,348 INFO [pretrained.py:267] 
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
index ee67a6edc..6f760a9ce 100644
--- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
@@ -148,8 +148,8 @@ Some commonly used options are:
 
         $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
 
-    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
-    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
+    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, 
+    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, 
     ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_lstm_ctc/pretrained.py
+  ./tdnn_lstm_ctc/pretrained.py 
     --method 1best
-    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
-    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
-    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt 
+    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt 
+    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -335,7 +335,7 @@ The output is:
   2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
   2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 21:02:54,387 INFO [pretrained.py:267]
+  2021-11-08 21:02:54,387 INFO [pretrained.py:267] 
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
 
@@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
     --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.08 \
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -376,7 +376,7 @@ The decoding output is:
   2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:05:27,878 INFO [pretrained.py:267]
+  2021-11-08 20:05:27,878 INFO [pretrained.py:267] 
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
index 387c14acf..fb2751c0f 100755
--- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
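Editorial note: the hunk above only re-wraps the speed-perturbation expression, but the pattern is the interesting part: for training partitions the cut set is tripled by adding 0.9x and 1.1x speed-perturbed copies before features are computed. A minimal sketch of that idea, assuming lhotse is installed; the function name and the manifest path argument are illustrative, not taken from this script:

```python
# Sketch only: triple a training CutSet with 0.9x / 1.1x speed perturbation,
# mirroring the pattern in compute_fbank_aidatatang_200zh.py.
from lhotse import CutSet


def triple_with_speed_perturb(cuts_path: str) -> CutSet:
    # cuts_path is illustrative, e.g. some *_cuts_train.jsonl.gz manifest.
    cut_set = CutSet.from_file(cuts_path)
    # Each perturb_speed() call returns a new CutSet with adjusted ids and
    # durations, so the result holds roughly 3x the original training audio.
    return cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
```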
@@ -114,7 +116,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_char.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/aidatatang_200zh/ASR/local/text2token.py
+++ b/egs/aidatatang_200zh/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
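Editorial note: the wrapped call above is the usual Python 2/3-compatible way of obtaining a UTF-8 text stream either from a file argument or from standard input. A small self-contained sketch of the same pattern; the helper name is made up for illustration:

```python
# Sketch: open a UTF-8 text stream from a path if given, else from stdin.
import codecs
import sys


def open_utf8_input(path=None):
    if path:
        return codecs.open(path, encoding="utf-8")
    # sys.stdin.buffer is the underlying binary stream in Python 3; wrapping
    # it with a UTF-8 reader mirrors the fallback branch in text2token.py.
    return codecs.getreader("utf-8")(sys.stdin.buffer)


for line in open_utf8_input():
    print(line.rstrip("\n"))
```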
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 4749e1b7f..039951354 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -106,10 +106,11 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   if [ ! -f $lang_char_dir/words.txt ]; then
     ./local/prepare_words.py \
       --input-file $lang_char_dir/words_no_ids.txt \
-      --output-file $lang_char_dir/words.txt
+      --output-file $lang_char_dir/words.txt 
   fi
 
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py
   fi
 fi
+
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 8c94f5bea..6a5b57e24 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class Aidatatang_200zhAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class Aidatatang_200zhAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class Aidatatang_200zhAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
     def train_dataloaders(
@@ -227,20 +205,24 @@ class Aidatatang_200zhAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -255,7 +237,9 @@ class Aidatatang_200zhAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,7 +282,9 @@ class Aidatatang_200zhAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -354,7 +340,9 @@ class Aidatatang_200zhAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
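Editorial note: most edits in this file replace parenthesised help strings with the older 80-column style, where adjacent string literals are concatenated implicitly. A tiny stand-alone sketch of that convention; the option name and default are copied from the hunk above, the rest is illustrative:

```python
# Sketch: a long argparse help text split across lines via implicit
# string concatenation, the style this file is being reverted to.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max-duration",
    type=int,
    default=200,
    help="Maximum pooled recordings duration (seconds) in a "
    "single batch. You can reduce it if it causes CUDA OOM.",
)
args = parser.parse_args([])  # parse an empty list so the sketch runs as-is
print(args.max_duration)
```

One thing to watch with this style is the space at each join; forgetting it produces run-together words in the rendered help text, as seen elsewhere in this patch (for example "tokens.txtUsed").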
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index 3f582ef04..f0407f429 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -69,7 +69,11 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -88,30 +92,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -193,7 +192,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,7 +249,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -264,7 +266,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -310,7 +315,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -381,7 +390,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -414,7 +425,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index 34f4d3ddf..00b54c39f 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -62,20 +62,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -106,7 +103,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -175,7 +173,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
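Editorial note: the only change here is wrapping the formatter string, but the same logging setup recurs in every script touched by this patch, so a minimal sketch may help; the logged message is illustrative:

```python
# Sketch: the logging format used across these scripts, prefixing each
# message with a timestamp, level, file name and line number.
import logging

formatter = (
    "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info("decoding started")  # illustrative message
```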
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index 3c96ed07b..eb5e6b0d4 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -85,11 +85,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -114,12 +112,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -166,7 +162,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -196,9 +193,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,7 +257,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,7 +284,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -336,7 +339,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
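Editorial note: the wrapped call above pads a batch of variable-length fbank feature matrices; using `math.log(1e-10)` as the padding value fills the tail frames with a very small log-energy instead of zeros. A self-contained sketch with random tensors standing in for real features:

```python
# Sketch: pad variable-length feature matrices into one batch tensor,
# as done before feeding the encoder in pretrained.py.
import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Two fake utterances with 50 and 30 frames of 80-dim fbank features.
features = [torch.randn(50, 80), torch.randn(30, 80)]
feature_lengths = torch.tensor([f.size(0) for f in features])

padded = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(padded.shape)     # torch.Size([2, 50, 80])
print(feature_lengths)  # tensor([50, 30])
```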
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index c7b1a4266..d46838b68 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -81,7 +81,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -185,45 +187,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -543,15 +542,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -705,7 +711,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -805,7 +813,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index f5b5873b4..cb7205e51 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
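Editorial note: the reformatting above touches the self-attention branch, where one packed projection matrix produces Q, K and V in a single `linear` call and the result is split with `chunk(3, dim=-1)`. A small stand-alone sketch of that packed projection; the dimensions are illustrative:

```python
# Sketch: packed QKV projection for self-attention, split via chunk().
import torch
import torch.nn as nn

embed_dim = 8
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

query = torch.randn(5, 2, embed_dim)  # (time, batch, embed_dim)
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
    3, dim=-1
)
print(q.shape, k.shape, v.shape)  # each is (5, 2, 8)
```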
@@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index a30fa52df..751b7d5b5 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -58,19 +58,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -404,7 +401,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -432,7 +431,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -440,7 +441,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -559,7 +562,9 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py
index 9ee405e8b..42b8c29e7 100644
--- a/egs/aishell/ASR/conformer_ctc/export.py
+++ b/egs/aishell/ASR/conformer_ctc/export.py
@@ -40,20 +40,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=84,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=25,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -160,7 +157,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index e3d5a20e3..27776bc24 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -46,29 +46,27 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
         "--tokens-file",
         type=str,
-        help="Path to tokens.txtUsed only when method is ctc-decoding",
+        help="Path to tokens.txt" "Used only when method is ctc-decoding",
     )
 
     parser.add_argument(
         "--words-file",
         type=str,
-        help="Path to words.txtUsed when method is NOT ctc-decoding",
+        help="Path to words.txt" "Used when method is NOT ctc-decoding",
     )
 
     parser.add_argument(
         "--HLG",
         type=str,
-        help="Path to HLG.pt.Used when method is NOT ctc-decoding",
+        help="Path to HLG.pt." "Used when method is NOT ctc-decoding",
     )
 
     parser.add_argument(
@@ -165,12 +163,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -214,9 +210,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -277,7 +274,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -372,7 +371,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py
index 8e0f73d05..542fb0364 100644
--- a/egs/aishell/ASR/conformer_ctc/subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/subsampling.py
@@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
-            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
index 81fa234dd..e3361d0c9 100755
--- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
@@ -16,8 +16,9 @@
 # limitations under the License.
 
 
+from subsampling import Conv2dSubsampling
+from subsampling import VggSubsampling
 import torch
-from subsampling import Conv2dSubsampling, VggSubsampling
 
 
 def test_conv2d_subsampling():
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index c2cbe6e3b..a228cc1fe 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -382,7 +382,9 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -518,7 +520,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -626,7 +630,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py
index a3e50e385..f93914aaa 100644
--- a/egs/aishell/ASR/conformer_ctc/transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,7 +183,9 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -262,17 +266,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -333,17 +343,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -818,7 +836,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -839,7 +859,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py
index f5b5873b4..cb7205e51 100644
--- a/egs/aishell/ASR/conformer_mmi/conformer.py
+++ b/egs/aishell/ASR/conformer_mmi/conformer.py
@@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index a43183063..4db367e36 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -59,19 +59,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -416,7 +413,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -444,7 +443,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -452,7 +453,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -547,7 +550,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -576,7 +581,9 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py
index 398837a46..720ed6c22 100644
--- a/egs/aishell/ASR/conformer_mmi/subsampling.py
+++ b/egs/aishell/ASR/conformer_mmi/subsampling.py
@@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
-            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 09cd6e60c..685831d09 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -511,7 +511,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -623,7 +625,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py
index a3e50e385..f93914aaa 100644
--- a/egs/aishell/ASR/conformer_mmi/transformer.py
+++ b/egs/aishell/ASR/conformer_mmi/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,7 +183,9 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -262,17 +266,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -333,17 +343,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -818,7 +836,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -839,7 +859,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
index 037971927..42700a972 100755
--- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -114,7 +116,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index 115ca1031..deab6c809 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -83,7 +83,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -109,7 +111,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/aishell/ASR/local/prepare_char.py
+++ b/egs/aishell/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/aishell/ASR/local/prepare_lang.py
+++ b/egs/aishell/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/aishell/ASR/local/test_prepare_lang.py
+++ b/egs/aishell/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index ae926ec66..a12934d55 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -76,7 +76,11 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -114,11 +118,9 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -186,7 +188,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,7 +249,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -258,7 +263,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -302,7 +310,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -375,7 +387,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +415,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -412,7 +428,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -456,7 +473,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -485,7 +504,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index 5f6888db4..feababdd2 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -50,7 +50,11 @@ from pathlib import Path
 import torch
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import str2bool
 
@@ -83,11 +87,9 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -118,7 +120,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -154,7 +157,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
@@ -187,7 +191,9 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        filename = (
+            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        )
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -195,14 +201,17 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir
+            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index f754a7b9e..3c38e5db7 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     add_model_arguments(parser)
@@ -201,9 +196,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -260,9 +256,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -310,7 +310,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -327,7 +329,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index 66ca23035..97d892754 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -49,6 +49,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
@@ -74,7 +75,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -200,7 +203,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -223,45 +227,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -560,7 +561,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -588,16 +593,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -713,7 +725,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -877,7 +891,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1015,7 +1029,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index 6c505940d..d159e420b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -121,24 +121,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -206,7 +202,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -266,7 +263,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -278,7 +277,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -322,7 +324,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -395,7 +401,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -421,7 +429,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -432,7 +442,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -477,7 +488,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -505,12 +518,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -537,12 +551,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -571,7 +586,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index e5a5d7c77..566902a85 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -88,24 +88,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -136,7 +132,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -169,12 +166,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -197,12 +195,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -230,7 +229,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -253,7 +252,9 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        filename = (
+            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        )
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -261,14 +262,17 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir
+            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
index a4dda0d6d..e150e8230 100644
--- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -84,7 +84,9 @@ class Transducer(nn.Module):
         self.decoder_datatang = decoder_datatang
         self.joiner_datatang = joiner_datatang
 
-        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
         if decoder_datatang is not None:
@@ -177,7 +179,9 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
         boundary[:, 2] = y_lens
         boundary[:, 3] = encoder_out_lens
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 109879952..04a0a882a 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     add_model_arguments(parser)
@@ -201,9 +196,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,9 +257,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -311,7 +311,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -328,7 +330,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index b24f533ff..feaef5cf6 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -96,7 +96,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -222,7 +224,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -245,45 +248,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -635,7 +635,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -666,16 +670,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -813,7 +824,9 @@ def train_one_epoch(
                 )
             # summary stats
             if datatang_train_dl is not None:
-                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+                tot_loss = (
+                    tot_loss * (1 - 1 / params.reset_interval)
+                ) + loss_info
 
             if aishell:
                 aishell_tot_loss = (
@@ -834,7 +847,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -877,7 +892,9 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             if datatang_train_dl is not None:
                 datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
-                tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                tot_loss_str = (
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                )
             else:
                 tot_loss_str = ""
                 datatang_str = ""
@@ -1050,7 +1067,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1059,7 +1076,9 @@ def run(rank, world_size, args):
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -1074,7 +1093,9 @@ def run(rank, world_size, args):
     if params.datatang_prob > 0:
         datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
         train_datatang_cuts = datatang.train_cuts()
-        train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
         train_datatang_cuts = train_datatang_cuts.repeat(times=None)
         datatang_train_dl = asr_datamodule.train_dataloaders(
             train_datatang_cuts,
@@ -1228,7 +1249,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 12ae6e7d4..d24ba6bb7 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -64,12 +64,10 @@ class AishellAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -81,74 +79,59 @@ class AishellAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -160,18 +143,17 @@ class AishellAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -185,40 +167,40 @@ class AishellAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -233,7 +215,9 @@ class AishellAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -276,7 +260,9 @@ class AishellAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -322,7 +308,9 @@ class AishellAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -378,9 +366,13 @@ class AishellAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
+        )
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 8ef247438..66b734fc4 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -49,19 +49,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -268,7 +265,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -290,7 +289,9 @@ def save_results(
         # We compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
             test_set_wers[key] = wer
@@ -334,7 +335,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -359,7 +362,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
 
     model.to(device)
     model.eval()
@@ -387,7 +392,9 @@ def main():
             lexicon=lexicon,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
index 1731e1ebe..5e04c11b4 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
@@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
+            [
+                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
+                for _ in range(5)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 52f9410cf..9bd810809 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -41,11 +41,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -55,7 +53,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -71,12 +71,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -114,9 +112,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -174,7 +173,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is [N, C, T]
 
     with torch.no_grad():
@@ -218,7 +219,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index e574cf89b..7619b0551 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -49,7 +49,12 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
 
 
 def get_parser():
diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index de0a8d0f5..9ed9b2ad1 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -47,9 +47,9 @@ def greedy_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -81,9 +81,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -157,7 +157,9 @@ class HypothesisList(object):
 
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -244,9 +246,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index e26c6c385..64114253d 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -155,7 +155,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -173,14 +175,18 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -214,7 +220,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -333,7 +341,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -349,7 +359,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -619,9 +631,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -689,25 +701,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -744,7 +764,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -756,7 +778,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -790,9 +814,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -815,7 +843,9 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 1f7bb14e1..780b0c4bb 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -52,19 +52,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -102,7 +99,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -229,7 +227,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -248,7 +248,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append([lexicon.token_table[i] for i in hyp])
 
     if params.decoding_method == "greedy_search":
@@ -317,7 +319,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -342,7 +346,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -353,7 +359,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -423,7 +430,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index 70e9e6c96..c2c6552a9 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -86,7 +86,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index e35b26fe0..4c6519b96 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -245,7 +243,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 591bbe44f..994305fc1 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -103,7 +103,9 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index 8effc9815..db89c4d67 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -73,11 +73,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -102,12 +100,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -121,7 +117,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -214,9 +211,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -275,7 +273,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -319,7 +319,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
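
Note on the pad_sequence call wrapped above: padded frames are filled with log(1e-10), i.e. essentially log-zero energy. A self-contained sketch with invented feature shapes:

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    features = [torch.randn(120, 80), torch.randn(95, 80)]   # two fbank matrices
    padded = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )
    assert padded.shape == (2, 120, 80)   # shorter utterance gets 25 padded frames
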
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 62ffff473..d54157709 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,7 +126,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -388,7 +389,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -501,7 +504,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -620,7 +625,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
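
Note on the info["frames"] computation reformatted above: it just counts subsampled frames for normalising the logged loss. A tiny sketch, assuming the usual subsampling factor of 4 and made-up lengths:

    import torch

    feature_lens = torch.tensor([1000, 860, 720])
    subsampling_factor = 4
    frames = (feature_lens // subsampling_factor).sum().item()
    assert frames == 250 + 215 + 180
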
diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py
index b3ff153c1..e851dcc32 100644
--- a/egs/aishell/ASR/transducer_stateless/transformer.py
+++ b/egs/aishell/ASR/transducer_stateless/transformer.py
@@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index 76e209f06..838e53658 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -29,7 +29,10 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
+from lhotse.dataset.input_strategies import (
+    OnTheFlyFeatures,
+    PrecomputedFeatures,
+)
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
@@ -43,69 +46,59 @@ class AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler "
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -119,22 +112,18 @@ class AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -148,11 +137,9 @@ class AsrDataModule:
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available. Used only in dev/test CutSet"
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available. Used only in dev/test CutSet",
         )
 
     def train_dataloaders(
@@ -175,7 +162,9 @@ class AsrDataModule:
         if cuts_musan is not None:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
@@ -184,7 +173,9 @@ class AsrDataModule:
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -261,7 +252,9 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
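
Note on the wrapped OnTheFlyFeatures call above, shown in isolation; every other K2SpeechRecognitionDataset argument is left at its default in this sketch:

    from lhotse import Fbank, FbankConfig
    from lhotse.dataset import K2SpeechRecognitionDataset
    from lhotse.dataset.input_strategies import OnTheFlyFeatures

    validate = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        return_cuts=True,
    )
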
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index fd4cb8385..ea3f94fd8 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -93,19 +93,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -173,7 +170,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -229,7 +227,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -241,7 +241,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -285,7 +288,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -358,7 +365,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -384,7 +393,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -395,7 +406,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -436,7 +448,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 32481829c..3bd2ceb11 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -243,7 +241,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index 55701a007..a95a4bc52 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     return parser
@@ -199,9 +194,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -258,9 +254,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,7 +308,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -325,7 +327,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index 8fb7d1e49..225d0d709 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,7 +149,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -167,7 +168,8 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the aidatatang_200zh dataset",
+        help="The probability to select a batch from the "
+        "aidatatang_200zh dataset",
     )
 
     return parser
@@ -447,7 +449,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -601,7 +605,9 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
@@ -729,7 +735,9 @@ def run(rank, world_size, args):
     train_datatang_cuts = train_datatang_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -768,7 +776,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 1e41942da..65fcda873 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -245,7 +245,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -289,7 +292,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -362,7 +369,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -388,7 +397,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -399,7 +410,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -440,7 +452,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index ca1d4bd4a..11335a834 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -243,7 +241,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 038090461..262e822c2 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     return parser
@@ -199,9 +194,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -258,9 +254,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,7 +308,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -325,7 +327,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index 5f116f2bd..d3ffccafa 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,7 +142,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -413,7 +414,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -526,7 +529,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,7 +657,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py
old mode 100644
new mode 100755
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index ec0c584ca..d8d3622bd 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -83,7 +83,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -109,7 +111,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
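
Note on the speed-perturbation expression reformatted above: it triples the training cuts (original plus 0.9x and 1.1x copies). A minimal sketch; the manifest path is hypothetical:

    from lhotse import load_manifest

    cut_set = load_manifest("data/fbank/aishell2_cuts_train.jsonl.gz")
    cut_set = (
        cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
    )
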
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
old mode 100644
new mode 100755
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
old mode 100644
new mode 100755
index e8966b554..b7a21f579
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -76,12 +76,10 @@ class AiShell2AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -93,74 +91,59 @@ class AiShell2AsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -172,18 +155,17 @@ class AiShell2AsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -197,22 +179,18 @@ class AiShell2AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -238,16 +216,20 @@ class AiShell2AsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -262,7 +244,9 @@ class AiShell2AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -306,7 +290,9 @@ class AiShell2AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -362,7 +348,9 @@ class AiShell2AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -418,7 +406,9 @@ class AiShell2AsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
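
Note on the CutMix construction wrapped earlier in this file: it mixes MUSAN noise into training cuts at 10-20 dB SNR with probability 0.5 while preserving cut ids. Shown standalone, with a hypothetical manifest path:

    from lhotse import load_manifest
    from lhotse.dataset import CutMix

    cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")
    transforms = [
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
    ]
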
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 64b64d1b1..915737f4a 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -168,24 +168,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -273,7 +269,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -351,7 +348,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -410,7 +409,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -536,7 +538,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -569,7 +573,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -620,7 +625,9 @@ def main():
             if "LG" in params.decoding_method:
                 params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -654,12 +661,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -682,12 +690,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -715,7 +724,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -740,7 +749,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index 547ce2069..bc7bd71cb 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -89,24 +89,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -137,7 +133,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -170,12 +167,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -198,12 +196,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -231,7 +230,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -267,7 +266,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
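The `--use-averaged-model` help above notes that only the checkpoints for epochs `epoch-avg` and `epoch` are loaded, yet the result is described as an average over the whole range `(epoch-avg, epoch]`. That is possible if each checkpoint also carries a cumulative running average of the parameters, because the range average can then be recovered from the two endpoints alone. A hedged sketch of that arithmetic (the `model_avg` key and the cumulative-average assumption are mine, not guaranteed to match the exported checkpoints):

    import torch


    def average_over_range(filename_start: str, filename_end: str,
                           start: int, end: int) -> dict:
        """If avg_k is the running mean of the model over epochs 1..k, then
        mean over epochs (start, end] =
            (end * avg_end - start * avg_start) / (end - start)."""
        avg_start = torch.load(filename_start, map_location="cpu")["model_avg"]
        avg_end = torch.load(filename_end, map_location="cpu")["model_avg"]
        weight = 1.0 / (end - start)
        return {
            k: (end * avg_end[k].float() - start * avg_start[k].float()) * weight
            for k in avg_end
        }
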
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 4b16511e8..09de1bece 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -81,11 +81,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -111,12 +109,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -163,7 +159,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -194,9 +191,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -256,11 +254,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -332,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
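The pretrained.py hunks above reflow the sample-rate assertion and the `pad_sequence`/encoder calls without changing behaviour. For reference, a compact sketch of that input pipeline: load waveforms, check the 16 kHz assumption, and pad variable-length feature matrices with log(1e-10) so padded frames look like near-silence in the log-mel domain (torchaudio's kaldi-compatible fbank stands in here for the feature extractor used by the script):

    import math
    from typing import List, Tuple

    import torch
    import torchaudio
    from torch.nn.utils.rnn import pad_sequence


    def load_and_pad(filenames: List[str],
                     expected_sample_rate: int = 16000
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
        feats = []
        for f in filenames:
            wave, sample_rate = torchaudio.load(f)
            assert sample_rate == expected_sample_rate, (
                f"expected sample rate: {expected_sample_rate}. "
                f"Given: {sample_rate}"
            )
            # 80-dim log-mel features from the first channel only.
            feats.append(
                torchaudio.compliance.kaldi.fbank(wave[0:1], num_mel_bins=80)
            )
        lengths = torch.tensor([f.size(0) for f in feats])
        # Pad with log(1e-10): padded frames carry (near) log-zero energy
        # instead of looking like loud speech to the encoder.
        features = pad_sequence(
            feats, batch_first=True, padding_value=math.log(1e-10)
        )
        return features, lengths
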
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index d37e7bdca..838a0497f 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -92,7 +92,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -218,7 +220,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -241,45 +244,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -603,7 +603,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -632,16 +636,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -760,7 +771,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -816,7 +829,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -924,7 +939,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1089,7 +1104,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
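Beyond the line re-wrapping, the train.py hunks keep the same warm-up schedule: the pruned RNN-T loss is switched off while warmup < 1, down-weighted by 0.1 for 1 < warmup < 2, and fully enabled afterwards, then combined with the scaled simple loss. A tiny sketch of that gating, with made-up loss values only to show the arithmetic:

    def combine_losses(simple_loss: float, pruned_loss: float,
                       warmup: float, simple_loss_scale: float = 0.5) -> float:
        """Warm-up-gated interpolation of the two transducer losses."""
        if warmup < 1.0:
            pruned_loss_scale = 0.0   # alignment not learned yet
        elif 1.0 < warmup < 2.0:
            pruned_loss_scale = 0.1   # gently introduce the pruned loss
        else:
            pruned_loss_scale = 1.0
        return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss


    # Early in training (warmup=0.5) only the simple loss contributes:
    assert combine_losses(2.0, 5.0, warmup=0.5) == 1.0
    assert combine_losses(2.0, 5.0, warmup=1.5) == 1.5
    assert combine_losses(2.0, 5.0, warmup=3.0) == 6.0
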
 
 
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 400c406f0..3f50d9e3e 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -85,7 +85,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -118,7 +120,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
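The compute_fbank hunk above only reflows the cut-set concatenation, but the pattern is worth spelling out: the training partition is tripled by adding 0.9x and 1.1x speed-perturbed copies before features are computed. A sketch built from the same Lhotse calls that appear in the hunks (the manifest path, output path, and num_jobs are illustrative assumptions):

    from lhotse import CutSet, Fbank, FbankConfig, load_manifest


    def extract_train_features(cuts_path: str, out_dir: str) -> CutSet:
        # Hypothetical manifest path; the recipe derives one per partition.
        cut_set = load_manifest(cuts_path)

        # Triple the data: original + 0.9x and 1.1x speed-perturbed copies.
        cut_set = (
            cut_set
            + cut_set.perturb_speed(0.9)
            + cut_set.perturb_speed(1.1)
        )

        # 80-dim log-mel filterbanks, matching num_mel_bins used elsewhere.
        extractor = Fbank(FbankConfig(num_mel_bins=80))
        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{out_dir}/feats_train",
            num_jobs=4,
        )
        return cut_set
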
 
diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/aishell4/ASR/local/prepare_char.py
+++ b/egs/aishell4/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
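The prepare_char.py hunks wrap the token-id lookup and the `generate_lexicon` signature; the key behaviour is the OOV fallback when a piece is missing from the token table. A small standalone sketch of that lookup (the `<unk>`-style OOV symbol is an assumption here, since the literal symbol is not visible in the hunk):

    from typing import Dict, List


    def map_pieces_to_ids(pieces: List[str], token2id: Dict[str, int],
                          oov: str = "<unk>") -> List[int]:
        """Map token pieces to ids, falling back to the OOV id for anything
        not present in the table (the OOV symbol name is assumed)."""
        oov_id = token2id[oov]
        return [token2id[p] if p in token2id else oov_id for p in pieces]


    token2id = {"<unk>": 0, "你": 1, "好": 2}
    print(map_pieces_to_ids(["你", "好", "吗"], token2id))  # [1, 2, 0]
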
diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/aishell4/ASR/local/prepare_lang.py
+++ b/egs/aishell4/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/aishell4/ASR/local/test_prepare_lang.py
+++ b/egs/aishell4/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/aishell4/ASR/local/text2token.py
+++ b/egs/aishell4/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
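text2token.py's `-n` help string above describes the grouping rule ("aabb -> a a b b with -n 1 and aa bb with -n 2"). A two-line sketch of that rule, to make the worked example concrete:

    def split_chars(text: str, n: int = 1) -> str:
        """Group every n characters and join the groups with spaces:
        split_chars("aabb", 1) -> "a a b b"; split_chars("aabb", 2) -> "aa bb"."""
        return " ".join(text[i:i + n] for i in range(0, len(text), n))


    assert split_chars("aabb", 1) == "a a b b"
    assert split_chars("aabb", 2) == "aa bb"
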
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 84c7f0443..7aa53ddda 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,12 +74,10 @@ class Aishell4AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
@@ -93,81 +91,66 @@ class Aishell4AsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
@@ -181,18 +164,17 @@ class Aishell4AsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -206,22 +188,18 @@ class Aishell4AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -244,20 +222,24 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -272,7 +254,9 @@ class Aishell4AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -316,7 +300,9 @@ class Aishell4AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -373,7 +359,9 @@ class Aishell4AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
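The asr_datamodule options above (--max-duration, --bucketing-sampler) are only re-wrapped, but the idea behind them is easy to miss in the reflowed help strings: batches are formed so that the pooled audio duration stays under --max-duration seconds, and drawing from duration-sorted buckets keeps padding low. A toy, framework-free sketch of that batching rule (the real sampling is handled by Lhotse's DynamicBucketingSampler):

    from typing import Iterable, Iterator, List, Tuple

    Cut = Tuple[str, float]  # (cut id, duration in seconds)


    def batch_by_duration(cuts: Iterable[Cut],
                          max_duration: float = 200.0) -> Iterator[List[Cut]]:
        """Greedily pack cuts into batches whose total duration stays under
        max_duration; sorting by duration first mimics bucketing and keeps
        cuts of similar length together (less padding per batch)."""
        batch: List[Cut] = []
        total = 0.0
        for cut in sorted(cuts, key=lambda c: c[1]):
            if batch and total + cut[1] > max_duration:
                yield batch
                batch, total = [], 0.0
            batch.append(cut)
            total += cut[1]
        if batch:
            yield batch


    cuts = [("a", 3.2), ("b", 11.8), ("c", 4.1), ("d", 9.7)]
    for b in batch_by_duration(cuts, max_duration=15.0):
        print(b)
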
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index 616a88937..14e44c7d9 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -117,24 +117,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -205,7 +201,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -263,7 +260,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -278,7 +277,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -324,7 +326,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -395,7 +401,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -428,7 +436,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -471,7 +480,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -499,12 +510,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -531,12 +543,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -565,7 +578,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
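The reflowed decode.py hunks above also show how results are keyed: each decoding method maps to a settings string (e.g. the beam/max_contexts/max_states triple for fast_beam_search) that later becomes part of the wer-summary file name. A small sketch of that key construction under the same naming conventions (the params fields are assumptions matching the flags shown above):

    from argparse import Namespace


    def results_key(params: Namespace) -> str:
        """Build the settings string used as the results-dict key and inside
        wer-summary-{test_set}-{key}-{suffix}.txt file names."""
        if params.decoding_method == "greedy_search":
            return "greedy_search"
        if params.decoding_method == "fast_beam_search":
            return (
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
            )
        return f"beam_size_{params.beam_size}"


    params = Namespace(decoding_method="fast_beam_search",
                       beam=4, max_contexts=4, max_states=8)
    print(results_key(params))  # beam_4_max_contexts_4_max_states_8
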
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index 3c580ff7b..993341131 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -89,24 +89,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -140,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -172,12 +169,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -204,12 +202,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -238,7 +237,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -277,7 +276,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index 8151442af..1fa893637 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -94,11 +94,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -124,12 +122,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -176,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -207,9 +204,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -268,11 +266,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -304,7 +306,10 @@ def main():
 
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -345,7 +350,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index aacd23ecd..0a48b9059 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -85,7 +85,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -211,7 +213,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -234,45 +237,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -599,7 +599,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -629,15 +633,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -816,7 +827,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -924,7 +937,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index 96115a230..af926aa53 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -84,7 +84,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -119,7 +121,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/alimeeting/ASR/local/prepare_char.py
+++ b/egs/alimeeting/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/alimeeting/ASR/local/prepare_lang.py
+++ b/egs/alimeeting/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/alimeeting/ASR/local/test_prepare_lang.py
+++ b/egs/alimeeting/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
index 27b904fc8..7c1019aa8 100644
--- a/egs/alimeeting/ASR/local/text2segments.py
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -30,8 +30,8 @@ with word segmenting:
 
 import argparse
 
-import jieba
 import paddle
+import jieba
 from tqdm import tqdm
 
 paddle.enable_static()
diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/alimeeting/ASR/local/text2token.py
+++ b/egs/alimeeting/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index d0467a29e..bf6faad7a 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class AlimeetingAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class AlimeetingAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class AlimeetingAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
     def train_dataloaders(
@@ -227,20 +205,24 @@ class AlimeetingAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -255,7 +237,9 @@ class AlimeetingAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,7 +282,9 @@ class AlimeetingAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -355,7 +341,9 @@ class AlimeetingAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
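Note: the AlimeetingAsrDataModule hunks above only re-wrap argparse help strings; the option names and defaults are untouched. As a quick illustration of how these grouped options behave, a minimal self-contained sketch (defaults copied from the hunk; the str2bool helper below is a simplified stand-in for icefall.utils.str2bool):

import argparse


def str2bool(v: str) -> bool:
    # simplified stand-in for icefall.utils.str2bool
    return v.lower() in ("1", "true", "yes", "y")


parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="ASR data related options")
group.add_argument("--max-duration", type=int, default=200.0)
group.add_argument("--bucketing-sampler", type=str2bool, default=True)
group.add_argument("--num-buckets", type=int, default=300)

args = parser.parse_args(["--max-duration", "150", "--bucketing-sampler", "false"])
print(args.max_duration, args.bucketing_sampler, args.num_buckets)  # -> 150 False 300
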
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index ffaca1021..6358fe970 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -70,7 +70,11 @@ from beam_search import (
 from lhotse.cut import Cut
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -89,30 +93,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -194,7 +193,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,7 +249,9 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -264,7 +266,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -310,7 +315,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -381,7 +390,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -414,7 +425,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -551,7 +563,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -561,7 +574,8 @@ def main():
     )
 
     test_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
     ]
     cuts_test_webdataset = CutSet.from_webdataset(
         test_shards,
@@ -574,7 +588,9 @@ def main():
         return 1.0 <= c.duration
 
     cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
-    cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
+    cuts_test_webdataset = cuts_test_webdataset.filter(
+        remove_short_and_long_utt
+    )
 
     dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
     test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
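The decode.py changes above are likewise cosmetic. The dev/test loading they touch reduces to globbing the webdataset shards in sorted order and keeping only cuts of at least one second; a condensed, self-contained sketch (the directory below is a placeholder and FakeCut stands in for a Lhotse cut):

import glob
import os


def list_shards(directory: str) -> list:
    # the "shared-*.tar" shards are passed to CutSet.from_webdataset in sorted order
    return [str(p) for p in sorted(glob.glob(os.path.join(directory, "shared-*.tar")))]


def remove_short_and_long_utt(c) -> bool:
    # despite the name, only the lower bound (>= 1 s) is enforced in this recipe
    return 1.0 <= c.duration


class FakeCut:
    duration = 0.5


print(list_shards("download/alimeeting_dev"))  # [] here; shard paths on a prepared setup
print(remove_short_and_long_utt(FakeCut()))    # False
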
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 482e52d83..8beec1b8a 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -62,20 +62,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -106,7 +103,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -175,7 +173,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index afbf0960a..93b1e1f57 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -85,11 +85,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -114,12 +112,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -166,7 +162,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -196,9 +193,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,7 +257,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,7 +284,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -336,7 +339,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
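The pretrained.py flow reformatted above loads each file with torchaudio, requires a 16 kHz sample rate, keeps the first channel, and pads the per-utterance fbank matrices with log(1e-10) so padded frames resemble silence in log-mel space. A minimal sketch of that flow (the dummy feature tensors are only for illustration):

import math

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def read_waves(filenames, expected_sample_rate=16000):
    waves = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        )
        waves.append(wave[0])  # keep only the first channel
    return waves


# padding dummy fbank matrices of different lengths to a common length
features = [torch.zeros(5, 80), torch.zeros(3, 80)]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(features.shape)  # torch.Size([2, 5, 80])
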
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index 158ea9c1b..81a0ede7f 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -81,7 +81,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -185,45 +187,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -543,15 +542,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -705,7 +711,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -805,7 +813,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore
index cd0e20c4c..5d965832e 100644
--- a/egs/csj/ASR/.gitignore
+++ b/egs/csj/ASR/.gitignore
@@ -5,4 +5,4 @@ notify_tg.py
 finetune_*
 misc.ini
 .vscode/*
-offline/*
+offline/*
\ No newline at end of file
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index 036ce925f..994dedbdd 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -25,10 +25,15 @@ from random import Random
 from typing import List, Tuple
 
 import torch
-from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
+from lhotse import (
     CutSet,
     Fbank,
     FbankConfig,
+    # fmt: off
+    # See the following for why LilcomChunkyWriter is preferred
+    # https://github.com/k2-fsa/icefall/pull/404
+    # https://github.com/lhotse-speech/lhotse/pull/527
+    # fmt: on
     LilcomChunkyWriter,
     RecordingSet,
     SupervisionSet,
@@ -76,13 +81,17 @@ def make_cutset_blueprints(
         cut_sets.append((f"eval{i}", cut_set))
 
     # Create train and valid cuts
-    logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
+    logging.info(
+        "Loading, trimming, and shuffling the remaining core+noncore cuts."
+    )
     recording_set = RecordingSet.from_file(
         manifest_dir / "csj_recordings_core.jsonl.gz"
     ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
     supervision_set = SupervisionSet.from_file(
         manifest_dir / "csj_supervisions_core.jsonl.gz"
-    ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
+    ) + SupervisionSet.from_file(
+        manifest_dir / "csj_supervisions_noncore.jsonl.gz"
+    )
 
     cut_set = CutSet.from_manifests(
         recordings=recording_set,
@@ -92,12 +101,15 @@ def make_cutset_blueprints(
     cut_set = cut_set.shuffle(Random(RNG_SEED))
 
     logging.info(
-        f"Creating valid and train cuts from core and noncore,split at {split}."
+        "Creating valid and train cuts from core and noncore,"
+        f"split at {split}."
     )
     valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
 
     train_set = CutSet.from_cuts(islice(cut_set, split, None))
-    train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+    train_set = (
+        train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+    )
 
     cut_sets.extend([("valid", valid_set), ("train", train_set)])
 
@@ -110,9 +122,15 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
-    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
-    parser.add_argument("--split", type=int, default=4000, help="Split at this index")
+    parser.add_argument(
+        "--manifest-dir", type=Path, help="Path to save manifests"
+    )
+    parser.add_argument(
+        "--fbank-dir", type=Path, help="Path to save fbank features"
+    )
+    parser.add_argument(
+        "--split", type=int, default=4000, help="Split at this index"
+    )
 
     return parser.parse_args()
 
@@ -123,7 +141,9 @@ def main():
     extractor = Fbank(FbankConfig(num_mel_bins=80))
     num_jobs = min(16, os.cpu_count())
 
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
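make_cutset_blueprints, reformatted above, splits the shuffled core+noncore cuts so that the first --split cuts form the validation set and the rest the training set, which is then tripled by 0.9x and 1.1x speed perturbation. The split itself is plain itertools.islice; a toy illustration with a list standing in for a CutSet:

from itertools import islice

cuts = list(range(10))  # stand-in for the shuffled core+noncore CutSet
split = 4               # the real script defaults to --split 4000

valid = list(islice(cuts, 0, split))
train = list(islice(cuts, split, None))

print(valid)  # [0, 1, 2, 3]
print(train)  # [4, 5, 6, 7, 8, 9]
# In the recipe the training set is then extended as:
#   train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
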
 
diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py
index f60e62c85..44a33c4eb 100644
--- a/egs/csj/ASR/local/compute_fbank_musan.py
+++ b/egs/csj/ASR/local/compute_fbank_musan.py
@@ -26,6 +26,7 @@ from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
 
+
 ARGPARSE_DESCRIPTION = """
 This file computes fbank features of the musan dataset.
 It looks for manifests in the directory data/manifests.
@@ -83,7 +84,9 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
         # create chunks of Musan with duration 5 - 10 seconds
         musan_cuts = (
             CutSet.from_manifests(
-                recordings=combine(part["recordings"] for part in manifests.values())
+                recordings=combine(
+                    part["recordings"] for part in manifests.values()
+                )
             )
             .cut_into_windows(10.0)
             .filter(lambda c: c.duration > 5)
@@ -104,15 +107,21 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
-    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
+    parser.add_argument(
+        "--manifest-dir", type=Path, help="Path to save manifests"
+    )
+    parser.add_argument(
+        "--fbank-dir", type=Path, help="Path to save fbank features"
+    )
 
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = get_args()
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan(args.manifest_dir, args.fbank_dir)
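compute_fbank_musan.py builds its noise cuts by combining every MUSAN recording manifest, slicing the result into 10-second windows, and dropping windows of 5 seconds or less; the hunk above only re-indents that expression. Condensed from the hunk (lhotse API as used there):

from lhotse import CutSet, combine


def make_musan_noise_cuts(manifests: dict) -> CutSet:
    # `manifests` maps part name -> {"recordings": RecordingSet, ...}, as returned
    # by lhotse.recipes.utils.read_manifests_if_cached in the script above
    return (
        CutSet.from_manifests(
            recordings=combine(part["recordings"] for part in manifests.values())
        )
        .cut_into_windows(10.0)
        .filter(lambda c: c.duration > 5)
    )
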
diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini
index c987e72c5..eb70673de 100644
--- a/egs/csj/ASR/local/conf/disfluent.ini
+++ b/egs/csj/ASR/local/conf/disfluent.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj
+; # lhotse prepare csj  
 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.  
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
-; # Pass an empty string to avoid adding any symbol. It was "" in kaldi.
-; # If you intend to use a multicharacter string for gap_sym, remember to register the
-; # multicharacter string as part of userdef-string in prepare_lang_char.py.
-gap_sym =
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
+; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. 
+; # If you intend to use a multicharacter string for gap_sym, remember to register the 
+; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
+gap_sym = 
 
 [CONSTANTS]
 ; # Name of this mode
@@ -115,59 +115,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))',
+; # Example: 'コク(笑 サイ+(D オン))', 
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)'
+; # Example: '(泣 ドンナニ)' 
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ'
+; # Example: 'シャ(咳 リン) ノ' 
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0
 
 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
- =
+ = 
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
- =
+ = 
 ; # 非語彙的な母音の引き延ばし
- =
+ = 
 ; # 非語彙的な子音の引き延ばし
- =
+ = 
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> =
+<笑> = 
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> =
+<咳> = 
 ; # 言語音と独立に講演者の息が生じている場合
-<息> =
+<息> = 
 ; # 講演者の泣き声
-<泣> =
+<泣> = 
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> =
+<フロア発話> = 
 ; # 聴衆の笑い
-<フロア笑> =
+<フロア笑> = 
 ; # 聴衆の拍手
-<拍手> =
+<拍手> = 
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> =
+<デモ> = 
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> =
+<ベル> = 
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> =
+<朗読間違い> = 
 ; # 上記以外の音で特に目立った音
-<雑音> =
+<雑音> = 
 ; # 0.2秒以上のポーズ
-

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d22f9eb8 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..2613c3409 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..8ba451dd5 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,3 +319,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..c9de21073 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,7 +37,9 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") + parser.add_argument( + "--manifest-dir", type=Path, help="Path to cutset manifests" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index f0078421b..e4d996871 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,7 +68,8 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. " + "If not set, this will default to lang_char_{trans-mode}" ), ) @@ -86,7 +87,9 @@ def main(): args = get_args() logging.basicConfig( - format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + format=( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" + ), level=logging.INFO, ) @@ -108,7 +111,8 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name}" + f" at {args.trans_mode} mode." ) for cut in train_set: try: @@ -119,7 +123,8 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in " + f"{cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -138,7 +143,9 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) + (args.lang_dir / "userdef_string").write_text( + "\n".join(args.userdef_string) + ) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 89448a49c..0c4c6c1ea 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,7 +68,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -88,7 +89,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index c3e3e84bf..d78e26240 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,12 +61,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - 
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -78,91 +76,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -176,22 +158,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -205,25 +183,30 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -238,7 +221,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -271,7 +256,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -317,7 +304,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -373,7 +362,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ 
class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index b38ae9c8c..9c1418baa 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,19 +62,16 @@ def get_parser(): "--epoch", type=int, default=0, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=1, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -479,7 +476,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -494,7 +493,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -527,7 +528,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -702,7 +705,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index 880aa76e2..ef53b77f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,7 +73,8 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index 3b94f0c4b..cdc85ce9a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,10 +78,13 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - 
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 4883d04d8..2965cde18 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,7 +386,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -519,7 +521,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -637,7 +641,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 07beeb1f0..8209ee3ec 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,7 +77,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 0ee845ec8..6410249db 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,10 +47,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -136,7 +134,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 31abe7fff..48d10a157 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,13 +98,19 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 9ae3f071e..c87686e1e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,12 +73,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." 
- ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -90,91 +88,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -188,22 +170,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -217,7 +195,8 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders( @@ -237,16 +216,20 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -261,7 +244,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -304,7 +289,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -360,7 +347,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -416,7 +405,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 9f5d4711b..5849a3471 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,7 +77,11 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -114,11 +118,9 @@ def get_parser(): "--avg", type=int, default=8, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -186,7 +188,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -255,7 +258,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -270,7 +275,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -316,7 +324,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -386,7 +398,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -420,7 +434,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -496,7 +511,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 4d1a2356d..83ae25561 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -176,45 +178,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -554,16 +553,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -726,7 +732,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 0169d0f82..2828e309e 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,19 +61,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -234,7 +231,9 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): + for cut, labels, aux_labels in zip( + cut_list, labels_ali, aux_labels_ali + ): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -259,7 +258,9 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return CutSet.from_cuts(cuts) @@ -288,7 +289,9 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + out_manifest_filename = ( + out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + ) for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != 
x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 66fdf82d9..3f3b1acda 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,19 +64,16 @@ def get_parser(): "--epoch", type=int, default=77, - help=( - "It specifies the 
checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=55, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -554,7 +551,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -569,7 +568,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -601,7 +602,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -806,7 +809,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index bdb8a85e5..28c28df01 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,20 +40,17 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -160,7 +157,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cb0d6e04d..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,10 +82,13 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 8cabf1a53..a2c0a5486 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,11 +48,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -191,12 +189,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -240,9 +236,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -303,7 +300,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -428,7 +427,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 1a1c2f4c5..6419f6816 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,7 +393,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -420,7 +422,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -449,7 +453,9 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) + .sum() + .item() ) return loss, info @@ -562,7 +568,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -652,7 +660,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -725,7 +733,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 356d3f21b..1375d7245 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,10 +18,11 @@ from typing import Optional, Tuple import torch import torch.nn as nn -from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ +from scaling import ScaledLinear + class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -75,7 +76,9 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) self.num_heads = num_heads self.dropout = dropout @@ -91,7 +94,9 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) + self.in_proj_weight = ScaledLinear( + embed_dim, 3 * embed_dim, bias=bias + ) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -102,8 +107,12 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) - self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_k = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index a6f1679ef..b906d2650 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,8 +29,9 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from subsampling import Conv2dSubsampling from torch import Tensor, nn +from subsampling import Conv2dSubsampling + from transformer import Supervisions, Transformer, encoder_padding_mask @@ -181,7 +182,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -353,7 +356,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -368,7 +373,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): 
self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -643,9 +650,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -714,25 +721,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -769,7 +784,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -777,9 +794,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -813,9 +834,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -838,7 +863,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) 
-> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 934177b1f..97f2f2d39 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,11 +90,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -132,13 +130,11 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -662,7 +658,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" 
this_batch = [] hyp_words = [] for ref_text in texts: @@ -677,7 +675,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -709,7 +709,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -850,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -878,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -911,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -981,7 +985,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 0e1841d8d..584b3c3fc 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,7 +47,6 @@ import logging from pathlib import Path import torch -from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -56,8 +55,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon +from conformer import Conformer + from icefall.utils import str2bool +from icefall.lexicon import Lexicon def get_parser(): @@ -88,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. 
Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,12 +177,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -208,12 +206,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -241,7 +240,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 4d7137ad7..18fa3e69f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall import diagnostics from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -496,7 +498,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
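For the `warmup` argument documented above, a linear ramp like the following is the kind of schedule typically fed in from the training loop (an illustrative assumption only; the exact formula and step count live in train.py and are not reproduced here):

def warmup_value(batch_idx_train: int, model_warm_step: int = 3000) -> float:
    # Ramps from 0.0 to 1.0 over the first model_warm_step batches and keeps
    # growing afterwards; values >= 1.0 mean all modules are fully present.
    return batch_idx_train / model_warm_step

for step in (0, 1500, 3000, 6000):
    print(step, warmup_value(step))  # 0.0, 0.5, 1.0, 2.0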
""" - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -525,7 +531,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -552,7 +560,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -570,7 +580,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -708,7 +720,8 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} failing batch names {batch_name}" + f"failing batch size:{batch_size} " + f"failing batch names {batch_name}" ) raise @@ -763,9 +776,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( - "inf" - ): + if loss_info["ctc_loss"] == float("inf") or loss_info[ + "att_loss" + ] == float("inf"): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -778,7 +791,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -870,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index d3443dc94..3ef7edc23 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,17 +21,19 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from attention import MultiheadAttention from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling +from attention import MultiheadAttention +from torch.nn.utils.rnn import pad_sequence + from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledEmbedding, ScaledLinear, + ScaledEmbedding, ) -from subsampling import Conv2dSubsampling -from torch.nn.utils.rnn import pad_sequence + # Note: TorchScript requires Dict/List/etc. to be fully typed. 
Supervisions = Dict[str, torch.Tensor] @@ -208,7 +210,9 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + x = self.encoder( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) return x, mask @@ -257,17 +261,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -328,17 +338,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -943,7 +959,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -964,7 +982,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 4d9ddaea9..97c8d83a2 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,7 +156,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -174,14 +176,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -215,7 +221,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -334,7 +342,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -350,7 +360,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -620,9 +632,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -690,25 +702,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." 
+ ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -745,7 +765,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -757,7 +779,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -791,9 +815,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -816,7 +844,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index e8390ded9..fc9861489 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,19 +60,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -481,7 +478,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -513,7 +512,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -652,7 +653,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -684,7 +687,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index ad9415987..5c3e1222e 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,9 +25,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -111,13 +115,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index d0bb017dd..937845d77 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling import torch -from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 25d18076d..08e680607 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 import torch -from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, - add_eos, - add_sos, - decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, ) +from torch.nn.utils.rnn import pad_sequence + def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index f8c94cff9..011dadd73 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -361,7 +370,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -750,14 +762,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 5cfb2bfc7..9a5bdcce2 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ 
b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -368,7 +377,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -758,14 +770,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 2542d9abe..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,7 +148,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -180,7 +182,9 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -270,7 +274,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -335,7 +341,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -608,7 +616,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -877,7 +887,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -898,7 +910,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index a1c43f7f5..620d69a19 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from 
" f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 0639ba746..8ca7d5568 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." 
) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -533,7 +551,9 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query(torch.cat([right_context, utterance, summary])) + query = self.emb_to_query( + torch.cat([right_context, utterance, summary]) + ) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -544,12 +564,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -564,7 +588,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection outputs = self.out_proj(attention) @@ -646,7 +672,12 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. 
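The _forward_impl shown above folds (length, batch, embed_dim) tensors into (batch * nhead, length, head_dim) before two batched matmuls. A stripped-down sketch of just that reshape-and-bmm core (plain PyTorch; no caching, masking, dropout, or relative-position terms, so it is not the Emformer attention itself):

import torch

def simple_mha(query, key, value, nhead: int):
    Q, B, D = query.shape
    head_dim = D // nhead
    scaling = float(head_dim) ** -0.5

    def fold(t):
        # (L, B, D) -> (B * nhead, L, head_dim)
        return t.contiguous().view(-1, B * nhead, head_dim).transpose(0, 1)

    q, k, v = fold(query) * scaling, fold(key), fold(value)
    weights = torch.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)  # (B*nhead, Q, KV)
    attn = torch.bmm(weights, v)                                      # (B*nhead, Q, head_dim)
    return attn.transpose(0, 1).contiguous().view(Q, B, D)

out = simple_mha(torch.randn(7, 2, 16), torch.randn(9, 2, 16), torch.randn(9, 2, 16), nhead=4)
print(out.shape)  # torch.Size([7, 2, 16])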
""" - (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( utterance, right_context, summary, @@ -916,9 +947,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -957,10 +992,14 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) summary = summary[:1] else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) ( output_right_context_utterance, output_memory, @@ -975,7 +1014,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1110,7 +1151,11 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( + ( + src_att, + output_memory, + attn_cache, + ) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1250,7 +1295,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1269,7 +1316,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1430,7 +1479,9 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1592,8 +1643,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, 
self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1638,11 +1693,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1705,7 +1766,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 59105e286..4930881ea 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index c211b215e..9494e1fc1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,12 +68,14 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) self.hyp: Optional[List[int]] = None else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index abe83732a..61dbe8658 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += 
f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index a76417e5f..c07d8f76b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 9cb4a5afc..98b8290b5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No 
checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 09200f2e1..f16f5acc7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) 
x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -543,12 +561,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -563,7 +585,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -881,11 +905,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:-1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -922,12 +948,18 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) - (output_right_context_utterance, next_key, next_val,) = self.attention.infer( + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + ( + output_right_context_utterance, + next_key, + next_val, + ) = self.attention.infer( utterance=utterance, 
right_context=right_context, memory=pre_memory, @@ -935,7 +967,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, attn_cache def forward( @@ -1192,7 +1226,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1211,7 +1247,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1511,8 +1549,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1557,11 +1599,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1624,7 +1672,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index 4d05b367c..ab15e0241 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 0486ac2eb..71150392d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with 
open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2c2593b56..2bbc45d78 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cc34a72d8..fe6a26c51 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,7 +157,9 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") + logging.info( + f"{part} has {len(alignments.keys())} cuts with alignments." 
+ ) # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -168,14 +170,18 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info(f"Warning: {origin_id} does not have alignment.") + logging.info( + f"Warning: {origin_id} does not have alignment." + ) ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index df6c609bb..c628dfd53 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,7 +159,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 19bf3bff4..45c4b7f5f 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,7 +132,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 97750f3ea..c0c7ef8c5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,7 +80,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 37fce11f4..5587106e5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,10 +48,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." 
+ "Determines batch size dynamically.", ) parser.add_argument( @@ -146,7 +144,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..ce7d087f0 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,7 +112,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -126,7 +128,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 4a4093ae4..056da29e5 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,7 +83,9 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -99,7 +101,9 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index f149b7871..133499c8b 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,19 +46,21 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help=( - "The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words." - ), + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument("--oov", type=str, default="", help="The OOV word.") + parser.add_argument( + "--oov", type=str, default="", help="The OOV word." 
+ ) return parser.parse_args() -def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 3518db524..030122aa7 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,7 +87,9 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index fbcc9e24a..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,7 +79,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) removed += 1 return False @@ -124,7 +125,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. " + f"{ratio:.3f}% data is removed." ) return ans @@ -153,7 +155,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 3459c2f5a..566c0743d 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,7 +91,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e121aefa9..dec8a7442 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,7 +150,9 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. 
- words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + words_pieces: List[List[str]] = [ + sp.id_to_piece(ids) for ids in words_pieces_ids + ] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 70343fef7..5070341f1 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,7 +137,8 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " + f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -153,14 +154,18 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + output["sentence_lengths"] = torch.tensor( + sentence_lengths, dtype=torch.int32 + ) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 8aa5e461d..077f23039 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,7 +119,9 @@ def preprocess_giga_speech(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 807aaf891..7c57d629a 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,7 +64,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -84,7 +85,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100644 new mode 100755 index e69de29bb..27414d717 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./lstm_transducer_stateless/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + 
average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + feature = torch.nn.functional.pad( + feature, + (0, 0, 0, num_tail_padded_frames), + mode="constant", + value=LOG_EPS, + ) + feature_lens += num_tail_padded_frames + + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return 
{"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + 
else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
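# Illustrative note, not part of the patch: with --return-cuts enabled just
# below, the lhotse dataset attaches the original cuts to every batch, which is
# what lets decode_dataset() above do
#     cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
# and key every hypothesis by its utterance id in the result files.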
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100644 new mode 100755 index e69de29bb..13dac6009 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 \ + --jit-trace 1 + +It will generate 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + +(2) Export `model.state_dict()` + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `lstm_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./lstm_transducer_stateless/decode.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" 
--iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100644 new mode 100755 index e69de29bb..594c33e4f --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. 
+You can use the following command to get the exported models: + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +Usage of this script: + +./lstm_transducer_stateless/jit_pretrained.py \ + --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. 
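# Illustrative note, not part of the patch (lengths invented): the loop below
# is driven by pack_padded_sequence.  For encoder outputs of lengths [3, 2, 1],
# packed.batch_sizes == tensor([3, 2, 1]); at frame t only the first
# batch_sizes[t] utterances are still active, so decoder_out is truncated to
# that many rows before each joiner call.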
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + states = encoder.get_init_states(batch_size=features.size(0), device=device) + + encoder_out, encoder_out_lens, _ = encoder( + x=features, + x_lens=feature_lengths, + states=states, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + 
encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index e69de29bb..c54a4c478 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -0,0 +1,871 @@ +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, + ScaledLSTM, +) +from torch import nn + +LOG_EPSILON = math.log(1e-10) + + +def unstack_states( + states: Tuple[torch.Tensor, torch.Tensor] +) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Unstack the lstm states corresponding to a batch of utterances into a list + of states, where the i-th entry is the state from the i-th utterance. + + Args: + states: + A tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + + Returns: + A list of states. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + """ + hidden_states, cell_states = states + + list_hidden_states = hidden_states.unbind(dim=1) + list_cell_states = cell_states.unbind(dim=1) + + ans = [ + (h.unsqueeze(1), c.unsqueeze(1)) + for (h, c) in zip(list_hidden_states, list_cell_states) + ] + return ans + + +def stack_states( + states_list: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Stack list of lstm states corresponding to separate utterances into a single + lstm state so that it can be used as an input for lstm when those utterances + are formed into a batch. + + Args: + state_list: + Each element in state_list corresponds to the lstm state for a single + utterance. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + + + Returns: + A new state corresponding to a batch of utterances. + It is a tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. 
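# Illustrative round trip, not part of the patch, using the two helpers in this
# file; the sizes (num_layers=12, d_model=512, rnn_hidden_size=1024) are just
# examples.
h = torch.zeros(12, 1, 512)     # (num_layers, 1, d_model) hidden state
c = torch.zeros(12, 1, 1024)    # (num_layers, 1, rnn_hidden_size) cell state
batched = stack_states([(h, c), (h, c)])  # -> (12, 2, 512) and (12, 2, 1024)
per_utt = unstack_states(batched)         # back to a list of two (h, c) tuples
assert per_utt[0][0].shape == h.shape and per_utt[0][1].shape == c.shape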
+ """ + hidden_states = torch.cat([s[0] for s in states_list], dim=1) + cell_states = torch.cat([s[1] for s in states_list], dim=1) + ans = (hidden_states, cell_states) + return ans + + +class RNN(EncoderInterface): + """ + Args: + num_features (int): + Number of input features. + subsampling_factor (int): + Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa + d_model (int): + Output dimension (default=512). + dim_feedforward (int): + Feedforward dimension (default=2048). + rnn_hidden_size (int): + Hidden dimension for lstm layers (default=1024). + num_encoder_layers (int): + Number of encoder layers (default=12). + dropout (float): + Dropout rate (default=0.1). + layer_dropout (float): + Dropout value for model-level warmup (default=0.075). + aux_layer_period (int): + Period of auxiliary layers used for random combiner during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 512, + dim_feedforward: int = 2048, + rnn_hidden_size: int = 1024, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + aux_layer_period: int = 0, + is_pnnx: bool = False, + ) -> None: + super(RNN, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + encoder_layer = RNNEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + rnn_hidden_size=rnn_hidden_size, + dropout=dropout, + layer_dropout=layer_dropout, + ) + self.encoder = RNNEncoder( + encoder_layer, + num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ) + if aux_layer_period > 0 + else None, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C), where N is the batch size, + T is the sequence length, C is the feature dimension. + x_lens: + A tensor of shape (N,), containing the number of frames in `x` + before padding. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A tuple of 3 tensors: + - embeddings: its shape is (N, T', d_model), where T' is the output + sequence lengths. 
+ - lengths: a tensor of shape (batch_size,) containing the number of + frames in `embeddings` before padding. + - updated states, whose shape is the same as the input states. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() + + if states is None: + x = self.encoder(x, warmup=warmup)[0] + # torch.jit.trace requires returned types to be the same as annotated # noqa + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_encoder_layers, + x.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(1), + self.rnn_hidden_size, + ) + x, new_states = self.encoder(x, states) + + x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, batch_size: int = 1, device: torch.device = torch.device("cpu") + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get model initial states.""" + # for rnn hidden states + hidden_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.d_model), device=device + ) + cell_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.rnn_hidden_size), + device=device, + ) + return (hidden_states, cell_states) + + +class RNNEncoderLayer(nn.Module): + """ + RNNEncoderLayer is made up of lstm and feedforward networks. + + Args: + d_model: + The number of expected features in the input (required). + dim_feedforward: + The dimension of feedforward network model (default=2048). + rnn_hidden_size: + The hidden dimension of rnn layer. + dropout: + The dropout value (default=0.1). + layer_dropout: + The dropout value for model-level warmup (default=0.075). + """ + + def __init__( + self, + d_model: int, + dim_feedforward: int, + rnn_hidden_size: int, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(RNNEncoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) + self.lstm = ScaledLSTM( + input_size=d_model, + hidden_size=rnn_hidden_size, + proj_size=d_model if rnn_hidden_size > d_model else 0, + num_layers=1, + dropout=0.0, + ) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer. 
+ + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (1, N, d_model); + states[1] is the cell states of all layers, + with shape of (1, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # lstm module + if states is None: + src_lstm = self.lstm(src)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) + src_lstm, new_states = self.lstm(src, states) + src = self.dropout(src_lstm) + src + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src, new_states + + +class RNNEncoder(nn.Module): + """ + RNNEncoder is a stack of N encoder layers. + + Args: + encoder_layer: + An instance of the RNNEncoderLayer() class (required). + num_layers: + The number of sub-encoder-layers in the encoder (required). + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: Optional[List[int]] = None, + ) -> None: + super(RNNEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + self.d_model = encoder_layer.d_model + self.rnn_hidden_size = encoder_layer.rnn_hidden_size + + self.aux_layers: List[int] = [] + self.combiner: Optional[nn.Module] = None + if aux_layers is not None: + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer in turn. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. 
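# Illustrative streaming sketch for the top-level RNN encoder defined above,
# not part of the patch; the model sizes and chunk length are made up.
encoder = RNN(
    num_features=80, d_model=512, rnn_hidden_size=1024, num_encoder_layers=12
)
encoder.eval()  # the stateful path asserts that the model is not training
states = encoder.get_init_states(batch_size=1)
chunk = torch.randn(1, 20, 80)  # (N, T, num_features); T must be >= 9
chunk_lens = torch.full((1,), 20, dtype=torch.int64)
with torch.no_grad():
    out, out_lens, states = encoder(chunk, chunk_lens, states)
# feed the next chunk together with the returned `states` to continue the stream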
+ """ + if states is not None: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_layers, + src.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_layers, + src.size(1), + self.rnn_hidden_size, + ) + + output = src + + outputs = [] + + new_hidden_states = [] + new_cell_states = [] + + for i, mod in enumerate(self.layers): + if states is None: + output = mod(output, warmup=warmup)[0] + else: + layer_state = ( + states[0][i : i + 1, :, :], # h: (1, N, d_model) + states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) + ) + output, (h, c) = mod(output, layer_state) + new_hidden_states.append(h) + new_cell_states.append(c) + + if self.combiner is not None and i in self.aux_layers: + outputs.append(output) + + if self.combiner is not None: + output = self.combiner(outputs) + + if states is None: + new_states = (torch.empty(0), torch.empty(0)) + else: + new_states = ( + torch.cat(new_hidden_states, dim=0), + torch.cat(new_cell_states, dim=0), + ) + + return output, new_states + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-3)//2-1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >= 9, in_channels >= 9. + out_channels + Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 9 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=0, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+ + Returns: + Return a tensor of shape (N, ((T-3)//2-1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + # Now x is of shape (N, ((T-3)//2-1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. 
+ """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting(): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. 
+ """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev # noqa + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = RNN(num_features=feature_dim, d_model=128) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + feature_dim = 80 + m = RNN( + num_features=feature_dim, + d_model=512, + rnn_hidden_size=1024, + dim_feedforward=2048, + num_encoder_layers=12, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = m( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e69de29bb..d71132b4a 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -0,0 +1,210 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). 
+ It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and + (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output + contains unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + reduction: str = "sum", + delay_penalty: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + reduction: + "sum" to sum the losses over all utterances in the batch. + "none" to return the loss in a 1-D tensor for each utterance + in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert reduction in ("sum", "none"), reduction + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction=reduction, + delay_penalty=delay_penalty, + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + delay_penalty=delay_penalty, + reduction=reduction, + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100644 new mode 100755 index e69de29bb..2a6e2adc6 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
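# Rough sketch, not part of the patch, of how the (simple_loss, pruned_loss)
# pair returned by Transducer.forward() above is commonly combined in the
# accompanying train.py.  The scales, the warmup schedule and the variables
# feature/feature_lens/y/warmup are illustrative placeholders, not values taken
# from this patch.
simple_loss, pruned_loss = model(
    x=feature, x_lens=feature_lens, y=y, prune_range=5, warmup=warmup
)
pruned_loss_scale = 0.0 if warmup < 1.0 else (0.1 if warmup < 2.0 else 1.0)
loss = 0.5 * simple_loss + pruned_loss_scale * pruned_loss
loss.backward()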
+""" +Usage: + +(1) greedy search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. + +Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by +./lstm_transducer_stateless/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index e69de29bb..97d890c82 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -0,0 +1,148 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class Stream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + LOG_EPS: float = math.log(1e-10), + ) -> None: + """ + Args: + params: + It's the return value of :func:`get_params`. + cut_id: + The cut id of the current stream. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + device: + The device to run this stream. + LOG_EPS: + A float value used for padding. + """ + self.LOG_EPS = LOG_EPS + self.cut_id = cut_id + + # Containing attention caches and convolution caches + self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # It uses different attributes for different decoding methods. 
+ self.context_size = params.context_size + self.decoding_method = params.decoding_method + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) + ) + self.hyp: Optional[List[int]] = None + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + self.ground_truth: str = "" + + self.feature: Optional[torch.Tensor] = None + # Make sure all feature frames can be used. + # We aim to obtain 1 frame after subsampling. + self.chunk_length = params.subsampling_factor + self.pad_length = 5 + self.num_frames = 0 + self.num_processed_frames = 0 + + # After all feature frames are processed, we set this flag to True + self._done = False + + def set_feature(self, feature: torch.Tensor) -> None: + assert feature.dim() == 2, feature.dim() + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + self.num_frames = feature.size(0) + num_tail_padded_frames + self.feature = torch.nn.functional.pad( + feature, + (0, 0, 0, self.pad_length + num_tail_padded_frames), + mode="constant", + value=self.LOG_EPS, + ) + + def get_feature_chunk(self) -> torch.Tensor: + """Get a chunk of feature frames. + + Returns: + A tensor of shape (ret_length, feature_dim). + """ + update_length = min( + self.num_frames - self.num_processed_frames, self.chunk_length + ) + ret_length = update_length + self.pad_length + + ret_feature = self.feature[ + self.num_processed_frames : self.num_processed_frames + ret_length + ] + # Cut off used frames. + # self.feature = self.feature[update_length:] + + self.num_processed_frames += update_length + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_feature + + @property + def id(self) -> str: + return self.cut_id + + @property + def done(self) -> bool: + """Return True if all feature frames are processed.""" + return self._done + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.decoding_method == "greedy_search": + return self.hyp[self.context_size :] + elif self.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] + else: + assert self.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100644 new mode 100755 index e69de29bb..d6376bdc0 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -0,0 +1,968 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method greedy_search \ + --use-averaged-model True + +(2) modified beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method modified_beam_search \ + --use-averaged-model True \ + --beam-size 4 + +(3) fast beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method fast_beam_search \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" +import argparse +import logging +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lstm import LOG_EPSILON, stack_states, unstack_states +from stream import Stream +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--sampling-rate", + type=float, + default=16000, + help="Sample rate of the audio", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded in parallel", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
+ """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + T = encoder_out.size(1) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (batch_size, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, +): + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + beam: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out, decoder_out, project_input=False
+        )
+        # logits is of shape (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    for i in range(batch_size):
+        streams[i].hyps = B[i]
+
+
+def fast_beam_search_one_best(
+    model: nn.Module,
+    streams: List[Stream],
+    encoder_out: torch.Tensor,
+    processed_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> None:
+    """It limits the maximum number of symbols per frame to 1.
+
+    A lattice is first obtained using modified beam search, and then
+    the shortest path within the lattice is used as the final output.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      streams:
+        A list of stream objects.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      processed_lens:
+        A tensor of shape (N,) containing the number of processed frames
+        in `encoder_out` before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+ """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + assert B == len(streams) + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyps[i] + + +def decode_one_chunk( + model: nn.Module, + streams: List[Stream], + params: AttributeDict, + decoding_graph: Optional[k2.Fsa] = None, +) -> List[int]: + """ + Args: + model: + The Transducer model. + streams: + A list of Stream objects. + params: + It is returned by :func:`get_params`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + A list of indexes indicating the finished streams. 
+ """ + device = next(model.parameters()).device + + feature_list = [] + feature_len_list = [] + state_list = [] + num_processed_frames_list = [] + + for stream in streams: + # We should first get `stream.num_processed_frames` + # before calling `stream.get_feature_chunk()` + # since `stream.num_processed_frames` would be updated + num_processed_frames_list.append(stream.num_processed_frames) + feature = stream.get_feature_chunk() + feature_len = feature.size(0) + feature_list.append(feature) + feature_len_list.append(feature_len) + state_list.append(stream.states) + + features = pad_sequence( + feature_list, batch_first=True, padding_value=LOG_EPSILON + ).to(device) + feature_lens = torch.tensor(feature_len_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) + + # Make sure it has at least 1 frame after subsampling + tail_length = params.subsampling_factor + 5 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, + ) + + # Stack states of all streams + states = stack_states(state_list) + + encoder_out, encoder_out_lens, states = model.encoder( + x=features, + x_lens=feature_lens, + states=states, + ) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + beam=params.beam_size, + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + processed_lens = ( + num_processed_frames // params.subsampling_factor + + encoder_out_lens + ) + fast_beam_search_one_best( + model=model, + streams=streams, + encoder_out=encoder_out, + processed_lens=processed_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + # Update cached states of each stream + state_list = unstack_states(states) + for i, s in enumerate(state_list): + streams[i].states = s + + finished_streams = [i for i, stream in enumerate(streams) if stream.done] + return finished_streams + + +def create_streaming_feature_extractor() -> Fbank: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return Fbank(opts) + + +def decode_dataset( + cuts: CutSet, + model: nn.Module, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +): + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The Transducer model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. 
+ + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = next(model.parameters()).device + + log_interval = 300 + + fbank = create_streaming_feature_extractor() + + decode_results = [] + streams = [] + for num, cut in enumerate(cuts): + # Each utterance has a Stream. + stream = Stream( + params=params, + cut_id=cut.id, + decoding_graph=decoding_graph, + device=device, + LOG_EPS=LOG_EPSILON, + ) + + stream.states = model.encoder.get_init_states(device=device) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + feature = fbank(samples) + stream.set_feature(feature) + stream.ground_truth = cut.supervisions[0].text + + streams.append(stream) + + while len(streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + while len(streams) > 0: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + else: + key = f"beam_size_{params.beam_size}" + + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-streaming-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.device = device + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if 
params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + model=model, + params=params, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + torch.manual_seed(20220810) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100644 new mode 100755 index e69de29bb..d30fc260a --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -0,0 +1,1157 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
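Both the decoding script above and the training script that follows rely on `model_avg`, a running average of the model parameters kept from the start of training. The sketch below illustrates, with made-up numbers for a single scalar parameter, why only the checkpoints at epoch `epoch-avg` and epoch `epoch` need to be loaded to average over that range: assuming each stored `model_avg` is a uniform average over all updates seen so far, the average over the in-between updates falls out as a weighted difference. The update counts and parameter values here are hypothetical.

# n_start / n_end: number of optimizer updates seen at the two checkpoints
# (stored in the checkpoints as batch_idx_train).
# avg_start / avg_end: the value of one parameter in the corresponding
# `model_avg` checkpoints, i.e. its mean over updates 1..n_start and 1..n_end.
n_start, n_end = 80_000, 120_000
avg_start, avg_end = 0.50, 0.75

# Mean of the parameter over updates n_start+1 .. n_end only.
avg_over_range = (n_end * avg_end - n_start * avg_start) / (n_end - n_start)
print(avg_over_range)  # 1.25: the mean over the last 40k updates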
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./lstm_transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir lstm_transducer_stateless/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mixed precision training:
+
+./lstm_transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir lstm_transducer_stateless/exp \
+  --full-libri 1 \
+  --max-duration 550
+"""
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from lstm import RNN
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of RNN encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Encoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--rnn-hidden-size",
+        type=int,
+        default=1024,
+        help="Hidden dim for LSTM layers.",
+    )
+
+    parser.add_argument(
+        "--aux-layer-period",
+        type=int,
+        default=0,
+        help="""Period of auxiliary layers used for random combination during training.
+        If set to 0, the random combiner is not used (default).
+        You can set a positive integer to use the random combiner, e.g., 3.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=35,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. 
`model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "dim_feedforward": 2048, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = RNN( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + rnn_hidden_size=params.rnn_hidden_size, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + aux_layer_period=params.aux_layer_period, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + # # overwrite it + # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] + # print(scheduler.base_lrs) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. 
" + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index f7e1b5a54..bad4e243e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,24 +185,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -299,7 +295,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -477,7 +474,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -536,7 +535,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -698,7 +700,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -731,7 +735,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -784,7 +789,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -819,12 +826,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -852,12 +860,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -886,7 +895,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -952,7 +961,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 
0ad00cda3..190673638 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,24 +146,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -229,7 +225,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -345,7 +342,9 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) + c = torch.rand( + encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size + ) warmup = 1.0 torch.onnx.export( @@ -446,9 +445,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -547,12 +550,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -581,12 +585,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: 
raise ValueError( @@ -615,7 +620,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -689,7 +694,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index 5a8efd718..da184b76f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,12 +86,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -126,9 +124,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -316,7 +315,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..fadeb4ac2 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3b471fa85..410de8d3d 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,7 +156,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -198,9 +200,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -283,7 +286,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index 7d931a286..bef0ad760 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,11 +92,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. 
" + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -121,12 +119,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -173,7 +169,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,9 +201,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -269,11 +267,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -345,7 +347,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index baff15ea6..e47a05a9e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,7 +144,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -186,9 +188,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -226,7 +229,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_input = torch.tensor( + hyp, dtype=torch.int32 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -305,7 +310,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -321,7 +328,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -334,7 +343,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index b31fefa0a..232d3dd18 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,12 +109,10 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -149,9 +147,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -200,7 +199,9 @@ class Model: sess_options=self.session_opts, ) - def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder( + self, x, h0, c0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -257,7 +258,9 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) + return self.run_joiner_decoder_proj( + torch.from_numpy(decoder_out).squeeze(1) + ) def run_joiner( self, @@ -300,7 +303,11 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, + { + self.joiner_encoder_proj.get_inputs()[ + 0 + ].name: encoder_out.numpy() + }, )[0] return torch.from_numpy(projected_encoder_out) @@ -319,7 +326,11 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, + { + self.joiner_decoder_proj.get_inputs()[ + 0 + ].name: decoder_out.numpy() + }, )[0] return torch.from_numpy(projected_decoder_out) @@ -358,7 +369,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) + decoder_input = torch.tensor( + [hyp], dtype=torch.int64 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -461,7 +474,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 08a895a75..5eaaf321f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -235,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -688,7 +692,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -701,9 +707,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -714,7 +725,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -945,7 +958,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -991,7 +1006,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -1139,7 +1155,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a8d5605fb..9eee19379 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,24 +182,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -294,7 +290,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -389,7 +386,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -442,7 +441,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -520,7 +522,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. 
Args: @@ -595,7 +599,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -604,7 +610,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -642,7 +650,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -669,7 +678,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -713,7 +724,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -745,12 +758,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -773,12 +787,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -806,7 +821,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -833,7 +848,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 51238f768..212c7bad6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,24 +122,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -176,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -284,12 +281,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -312,12 +310,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -345,7 +344,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -381,7 +380,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index 180ba8c72..a3443cf0a 100755 --- 
a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,12 +85,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -125,9 +123,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -315,7 +314,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..90bc351f4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,7 +661,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -758,14 +760,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 4f8049245..0e48fef04 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). 
" - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index 4e9063a40..cfa918ed5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,9 +101,8 @@ def get_parser(): "--epoch", type=int, default=40, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -120,24 +119,20 @@ def get_parser(): "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
+ "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -204,7 +199,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -363,7 +359,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -380,7 +378,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,7 +539,9 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -581,7 +583,8 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor + encoder_out_lens + num_processed_frames // params.subsampling_factor + + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -593,7 +596,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -768,7 +773,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -810,7 +816,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -844,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < 
params.avg: raise ValueError( @@ -872,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -905,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index a1d19fb73..60a5a2be7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,7 +87,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -230,45 +232,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -607,7 +606,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -647,7 +650,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -660,9 +665,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -673,7 +683,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -840,7 +852,10 @@ def train_one_epoch( rank=rank, ) - if batch_idx % params.log_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.log_interval == 0 + and not params.print_diagnostics + ): cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -857,7 +872,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if ( batch_idx > 0 @@ -992,7 +1009,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index fd2a5354a..8dd1459ca 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,18 +74,17 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." 
- ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -97,91 +96,75 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -195,22 +178,18 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders( @@ -229,16 +208,20 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -253,7 +236,9 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -296,7 +281,9 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -353,7 +340,9 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -400,17 +389,23 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 785a8f097..2e9bf3e0b 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,7 +302,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -318,7 +320,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -492,7 +496,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 3b6d0549d..295a35204 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple +from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface -from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - self.knowledge_base = create_knowledge_base( - knowledge_M, knowledge_N, knowledge_D - ) + + self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, + knowledge_D) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. 
- encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 + encoder_layer_fn = lambda: ConformerEncoderLayer( self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K, + knowledge_K ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,7 +187,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -207,14 +209,10 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup( - knowledge_M, - knowledge_N, - knowledge_D, - knowledge_K, - d_model, - knowledge_base, - ) + self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, + knowledge_D, knowledge_K, + d_model, + knowledge_base) self.norm_final = BasicNorm(d_model) @@ -313,7 +311,9 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) + self.layers = nn.ModuleList( + [encoder_layer_fn() for i in range(num_layers)] + ) self.num_layers = num_layers def forward( @@ -367,7 +367,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -382,7 +384,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -657,9 +661,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -728,25 +732,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. 
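(Illustrative aside, not part of the patch.) The shape checks above accept either a 2-D attn_mask of shape (tgt_len, src_len), which is broadcast over batch and heads after the unsqueeze, or a 3-D mask that already carries the flattened (batch * heads) dimension. The sizes below are made up:

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
attn_mask_2d = torch.zeros(tgt_len, src_len, dtype=torch.bool)
attn_mask_3d = torch.zeros(bsz * num_heads, tgt_len, src_len, dtype=torch.bool)

# the 2-D form is expanded exactly as in the code above
assert list(attn_mask_2d.unsqueeze(0).size()) == [1, tgt_len, src_len]
# the 3-D form must already be (batch * heads, tgt_len, src_len)
assert list(attn_mask_3d.size()) == [bsz * num_heads, tgt_len, src_len]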
# convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -783,7 +795,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -791,9 +805,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -827,9 +845,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -852,7 +874,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 65da19f27..b4a9af55a 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,7 +76,11 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -94,19 +98,16 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -185,7 +186,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -243,7 +245,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -258,7 +262,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -302,7 +309,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -374,7 +385,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -406,7 +419,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index 0b9c886c7..b6d94aaf1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index 2ca76a30c..db51fb1cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
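(Illustrative aside, not part of the patch.) In the Decoder.forward hunks above and below, the embedding output is left-padded by context_size - 1 frames during training, so a convolution whose kernel width equals context_size sees the current token plus context_size - 1 previous ones at every position; at inference time no padding is needed because only one output frame is queried. A toy run of just the padding step, with made-up values:

import torch
import torch.nn.functional as F

context_size = 2
embedding_out = torch.arange(1.0, 7.0).reshape(1, 1, 6)  # (N, C, T), i.e. after permute(0, 2, 1)
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
print(padded)  # tensor([[[0., 1., 2., 3., 4., 5., 6.]]]): one zero frame added on the left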
-from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F -from subsampling import ScaledConv1d from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -91,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -101,6 +102,7 @@ class Decoder(nn.Module): return embedding_out + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -169,13 +171,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -184,41 +181,34 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -227,38 +217,22 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, 
- 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def extra_repr(self) -> str: - s = ( - "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," - " scale={scale}" - ) + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 1af05d9c8..96d1a30fb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) return parser @@ -176,7 +174,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 68c663b66..35f75ed2a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..599bf2506 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,7 +63,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -134,7 +136,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..432bf8220 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if 
p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -166,14 +176,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -281,9 +295,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 8cc930927..7b05e2f00 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,29 +3,32 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import random import timeit -from typing import Optional, Tuple - import torch +from torch import Tensor +from torch import nn +from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd +from typing import Tuple, Optional from scaling import ScaledLinear -from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +import random from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. + + + + def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M**N, D)) + a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M ** N, D)) nn.init.uniform_(ans, -a, a) return ans - def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -44,9 +47,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup( - weights: Tensor, indexes: Tensor, knowledge_base: Tensor -) -> Tensor: +def weighted_matrix_lookup(weights: Tensor, + indexes: Tensor, + knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. 
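(Illustrative aside, not part of the patch.) What weighted_matrix_lookup computes, written out with made-up sizes: gather K rows of the knowledge base for every position and take their weighted sum, which is exactly the index_select plus batched matmul in the hunk below.

import torch

M, D, K = 16, 8, 4                          # knowledge-base rows, row dim, rows per lookup
knowledge_base = torch.randn(M, D)
weights = torch.rand(2, 3, K)               # (*, K) with * = (2, 3)
indexes = torch.randint(0, M, (2, 3, K))    # (*, K), row indexes into the base

lookup = knowledge_base[indexes]                      # (*, K, D)
ans = (weights.unsqueeze(-2) @ lookup).squeeze(-2)    # (*, D)

# the same result via an explicit loop, as a sanity check
ref = torch.stack([
    sum(w * knowledge_base[i] for w, i in zip(ws, idx))
    for ws, idx in zip(weights.reshape(-1, K), indexes.reshape(-1, K))
]).reshape(2, 3, D)
assert torch.allclose(ans, ref, atol=1e-5)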
@@ -62,9 +65,9 @@ def weighted_matrix_lookup( # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -73,9 +76,7 @@ def weighted_matrix_lookup( class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward( - ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor - ) -> Tensor: + def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -87,16 +88,15 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward( - weights.detach(), indexes.detach(), knowledge_base.detach() - ) + ctx.save_for_backward(weights.detach(), indexes.detach(), + knowledge_base.detach()) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) # (*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) #(*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad is False + assert weights.requires_grad == False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,19 +115,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. 
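(Minimal sketch, not taken from the patch.) The comment above describes a recompute-in-backward pattern: forward keeps only the small saved tensors, and backward rebuilds the large intermediate under torch.enable_grad() and lets autograd differentiate it, as the code just below does with the knowledge-base lookup. The toy below does not itself save memory; it only shows the mechanics of the pattern.

import torch

class MatmulRecompute(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x.detach(), w.detach())
        return x @ w                    # no autograd graph is kept for this product

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        x.requires_grad_(True)
        w.requires_grad_(True)
        with torch.enable_grad():       # rebuild the intermediate on the fly
            y = x @ w
        y.backward(gradient=grad_out)   # reuse torch's autograd for the gradients
        return x.grad, w.grad

x = torch.randn(5, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
MatmulRecompute.apply(x, w).sum().backward()
assert torch.allclose(x.grad, torch.ones(5, 2) @ w.detach().t())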
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul( - lookup, ans_grad.unsqueeze(-1) # (*, K, D) - ) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze( - -2 - ) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul(lookup, # (*, K, D) + ans_grad.unsqueeze(-1)) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -149,7 +146,6 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ - @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -158,23 +154,18 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - (logprobs,) = ctx.saved_tensors + logprobs, = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 + l = logprobs.reshape(-1, logprobs.shape[-1]) scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print( - "Negentropy[individual,combined] = ", - negentropy_individual.item(), - ", ", - negentropy.item(), - ) + print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -192,23 +183,18 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - - def __init__( - self, - M: int, - N: int, - D: int, - K: int, - embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001, - ): + def __init__(self, M: int, N: int, D: int, + K: int, embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, + initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) + self.out_proj = ScaledLinear(D, embedding_dim, + initial_scale = 4.0) self.M = M self.N = N self.K = K @@ -224,14 +210,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. 
""" - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -251,44 +237,38 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device, dtype=dtype), - torch.randn(B, T, E, device=device, dtype=dtype), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) + start = timeit.default_timer() - # Epoch 0, batch 0, loss 1.0109944343566895 - # Epoch 10, batch 0, loss 1.0146660804748535 - # Epoch 20, batch 0, loss 1.0119813680648804 - # Epoch 30, batch 0, loss 1.0105408430099487 - # Epoch 40, batch 0, loss 1.0077732801437378 - # Epoch 50, batch 0, loss 1.0050103664398193 - # Epoch 60, batch 0, loss 1.0033129453659058 - # Epoch 70, batch 0, loss 1.0014232397079468 - # Epoch 80, batch 0, loss 0.9977912306785583 - # Epoch 90, batch 0, loss 0.8274348974227905 - # Epoch 100, batch 0, loss 0.3368612825870514 - # Epoch 110, batch 0, loss 0.11323091387748718 - # Time taken: 17.591704960912466 +# Epoch 0, batch 0, loss 1.0109944343566895 +# Epoch 10, batch 0, loss 1.0146660804748535 +# Epoch 20, batch 0, loss 1.0119813680648804 +# Epoch 30, batch 0, loss 1.0105408430099487 +# Epoch 40, batch 0, loss 1.0077732801437378 +# Epoch 50, batch 0, loss 1.0050103664398193 +# Epoch 60, batch 0, loss 1.0033129453659058 +# Epoch 70, batch 0, loss 1.0014232397079468 +# Epoch 80, batch 0, loss 0.9977912306785583 +# Epoch 90, batch 0, loss 0.8274348974227905 +# Epoch 100, batch 0, loss 0.3368612825870514 +# Epoch 110, batch 0, loss 0.11323091387748718 +# Time taken: 17.591704960912466 for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -296,8 +276,7 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) - + print('Time taken: ', stop - start) def _test_knowledge_base_lookup_autocast(): K = 16 @@ -315,21 +294,14 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. 
+ y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device), - torch.randn(B, T, E, device=device), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -337,11 +309,12 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() + for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -350,9 +323,10 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) + print('Time taken: ', stop - start) -if __name__ == "__main__": + +if __name__ == '__main__': _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index 527c735eb..f726c2583 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple +from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,7 +79,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -147,7 +149,8 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -179,7 +182,11 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -195,12 +202,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -211,13 +218,19 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def 
forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -232,12 +245,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -277,7 +290,11 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -292,12 +309,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -636,7 +653,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -666,8 +685,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 3f21133a0..6293e081a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,23 +15,21 @@ # limitations under the License. -from typing import Optional, Tuple - import torch import torch.nn as nn from torch import Tensor +from typing import Tuple, Optional -def _activation_balancer_loss( - mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10, -): + +def _activation_balancer_loss(mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 + eps: float = 1.0e-10): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -52,32 +50,28 @@ def _activation_balancer_loss( """ loss_parts = [] - x_mean = mean_pos - mean_neg - x_mean_abs = (mean_pos + mean_neg + eps).detach() - x_rel_mean = x_mean / x_mean_abs + x_mean = mean_positive - mean_negative + x_mean_abs = (mean_positive + mean_negative + eps).detach() + x_rel_mean= x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = -(1 - min_positive) + min_positive - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( - 1.0 / (2 * min_positive) - ) + x_rel_mean_floor = (-(1-min_positive) + min_positive) + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = -(1.0 - max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( - 1.0 / (1 - x_rel_mean_ceil) - ) + x_rel_mean_ceil = - (1.0-max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -88,53 +82,43 @@ def _activation_balancer_loss( # 100% violated. loss_parts.append(max_abs_loss) + # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - # num + num if min_positive != 0.0: - pass + + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward(ctx, x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -142,16 +126,11 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -184,30 +163,29 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) + def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() - ) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + self.eps.exp()) ** -0.5 return x * scales + + class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -229,26 +207,27 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. 
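# --- Illustrative sketch, not part of this patch -------------------------------
# The BasicNorm.forward in the hunk above is an RMS-style normalization with a
# learned, log-parameterized epsilon.  The standalone helper below re-implements
# only that formula to show its effect on the channel-wise RMS; the names
# `basic_norm_demo` and `log_eps` are invented for this sketch.
import math
import torch

def basic_norm_demo(x: torch.Tensor,
                    log_eps: float = math.log(0.25),
                    channel_dim: int = -1) -> torch.Tensor:
    # scales = 1 / sqrt(mean(x**2 over channels) + exp(log_eps))
    scales = (torch.mean(x ** 2, dim=channel_dim, keepdim=True)
              + math.exp(log_eps)) ** -0.5
    return x * scales

x = 10.0 * torch.randn(2, 50, 256)      # deliberately large-magnitude input
y = basic_norm_demo(x)
print("input RMS :", (x ** 2).mean().sqrt().item())   # roughly 10
print("output RMS:", (y ** 2).mean().sqrt().item())   # roughly 1
# -------------------------------------------------------------------------------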
""" - - def __init__(self, *args, initial_scale: float = 1.0, **kwargs): + def __init__(self, *args, + initial_scale: float = 1.0, + **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -258,67 +237,56 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + def __init__(self, *args, + initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - _single(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, 
self.groups) + class ScaledConv2d(nn.Conv2d): @@ -329,58 +297,45 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - _pair(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) + + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -409,16 +364,12 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - ): + def __init__(self, channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -428,15 +379,10 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) class DoubleSwishFunction(torch.autograd.Function): @@ -454,7 +400,6 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. 
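# --- Illustrative sketch, not part of this patch -------------------------------
# The DoubleSwishFunction docstring above derives
# d/dx [x * sigmoid(x - 1)] = double_swish(x) * (1 - s(x)) + s(x),
# which is why only s(x) needs to be saved for backward.  The few lines below
# check that identity against autograd; all names here are invented for the check.
import torch

def double_swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x - 1.0)

x = torch.linspace(-4.0, 4.0, steps=9, requires_grad=True)
(autograd_grad,) = torch.autograd.grad(double_swish(x).sum(), x)

s = torch.sigmoid(x - 1.0)
manual_grad = double_swish(x) * (1.0 - s) + s    # identity from the docstring
assert torch.allclose(autograd_grad, manual_grad, atol=1e-6)
# -------------------------------------------------------------------------------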
""" - @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -466,17 +411,18 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) + + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -545,13 +491,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -560,40 +501,33 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -603,37 +537,24 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def 
extra_repr(self) -> str: - s = "{num_embeddings}, {embedding_dim}, scale={scale}" + s = '{num_embeddings}, {embedding_dim}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) @@ -644,13 +565,8 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -660,22 +576,17 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) - def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -710,7 +621,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == "__main__": +if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index a60d15c3b..2f6840166 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,7 +78,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -177,45 +179,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -555,16 +554,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -727,7 +733,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -827,7 +835,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 1df1650f3..2d5724d30 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,24 +123,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
" - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -208,7 +204,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -275,7 +272,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -290,7 +289,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -336,7 +338,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -409,7 +415,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -442,7 +450,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -485,7 +494,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -517,12 +528,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -545,12 +557,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, 
iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +591,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 008f40fb1..318cd5094 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,9 +272,13 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) + emformer_out, emformer_out_lens, states = self.model.infer( + x, x_lens, states + ) - if x.size(1) != (self.model.segment_length + self.model.right_context_length): + if x.size(1) != ( + self.model.segment_length + self.model.right_context_length + ): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 81afb523d..2375f5001 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -173,12 +170,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -201,12 +199,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -234,7 +233,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index ed6848879..2f019bcdb 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,7 +122,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 6b30d3be8..fed814f19 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,45 +209,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -569,7 +566,11 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -598,7 +599,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -779,7 +782,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -903,7 +908,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 830b37cfb..7af9cc3d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,7 +670,9 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -686,7 +688,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
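# --- Illustrative sketch, not part of this patch -------------------------------
# HypothesisList.add in the hunk above merges two beam-search paths that yield
# the same token sequence: with use_max=True it keeps the better log_prob,
# otherwise it combines them with torch.logaddexp, i.e. it adds the underlying
# probabilities.  Toy numbers below; nothing here comes from the repository.
import math
import torch

log_p_a = torch.tensor(math.log(0.20))   # path A, P = 0.20
log_p_b = torch.tensor(math.log(0.05))   # path B, same tokens, P = 0.05

merged_sum = torch.logaddexp(log_p_a, log_p_b)   # log(0.20 + 0.05)
merged_max = torch.maximum(log_p_a, log_p_b)     # log(0.20)
print(merged_sum.exp().item())   # ~0.25
print(merged_max.exp().item())   # ~0.20
# -------------------------------------------------------------------------------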
""" if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -888,7 +892,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1082,7 +1088,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -1122,7 +1130,9 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 03ad45f49..7b6338948 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,7 +128,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -167,11 +171,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -267,7 +269,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -380,7 +383,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if ( @@ -445,7 +450,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -576,7 +584,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -609,7 +619,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -667,7 +678,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -705,7 +718,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -743,7 +757,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index e522943c0..386248554 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,7 +75,9 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 + self.pad_length = ( + params.right_context + 2 + ) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -89,11 +91,13 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. 
- self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -122,10 +126,13 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa + self.num_processed_frames : self.num_processed_frames # noqa + + ret_length ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 72593173c..f4355e8a0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,7 +92,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 64708e524..b5a151878 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -194,7 +192,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2cca7fa27..73b651b3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,7 +130,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index a42b63b9c..eb95827af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -224,9 +221,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -294,7 +292,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,7 +381,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index 9e09200a1..dcf6dc42f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,10 +166,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index a50b4d4f0..d2cae4f9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -266,7 +269,9 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -286,7 +291,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -342,7 +349,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -413,7 +422,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -449,7 +460,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -521,7 +533,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index dd0331a60..399b11a29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -565,7 +562,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -585,7 +584,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -776,7 +777,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -894,7 +897,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -952,7 +956,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5e9428b60..b7c2010f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,7 +775,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -791,7 +793,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -986,7 +990,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -998,7 +1004,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1668,7 +1676,9 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + am_scores.values + + n_scale * ngram_lm_scores + + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1794,7 +1804,9 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1804,7 +1816,9 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + 
ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1827,7 +1841,9 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1979,7 +1995,9 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2014,7 +2032,10 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2046,7 +2067,9 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 34ff0d7e2..bc273d33b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. 
""" - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -776,7 +785,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -800,7 +811,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1114,9 +1127,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1185,25 +1198,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. 
# convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1243,15 +1264,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1293,17 +1322,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1322,9 +1355,13 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1461,12 +1498,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 32cd53be3..979a0e02e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,7 +132,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -173,11 +177,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -273,7 +275,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -394,7 +397,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -460,7 +465,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -506,7 +514,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -596,7 +608,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -629,7 +643,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -685,7 +700,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -723,7 +740,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No 
checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -761,7 +779,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b59928103..ba91302ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,11 +107,15 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( + -1 + ) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 90367bd03..f1a8ea589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -170,7 +173,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -218,7 +222,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 1954f4724..6a9d08033 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,7 +60,9 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..417c391d9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -150,7 +152,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..041a81f45 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] 
@@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -170,14 +180,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -285,9 +299,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 58de6875f..f52cb22ab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -225,9 +222,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -295,7 +293,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f671e97b1..8c572a9ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,7 +89,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -135,7 +137,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -227,7 +229,8 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -279,12 +282,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -298,7 +301,9 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): @@ -326,12 +331,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -395,12 +400,12 @@ class 
ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -471,7 +476,9 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) + self.grad_filter = GradientFilter( + batch_dim=1, threshold=grad_norm_threshold + ) self._reset_parameters( initial_speed @@ -479,8 +486,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 + a = (3 ** 0.5) * std + scale = self.hidden_size ** -0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -552,11 +559,15 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) + flat_weights.append( + self._flat_weights[idx] * self._scales[idx].exp() + ) self._flatten_parameters(flat_weights) return flat_weights - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -904,7 +915,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -934,8 +947,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -988,18 +1001,17 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median *" - " threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), + (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), + (x.grad ** 2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index e6e0fb1c8..9bcd2f9f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,7 +153,9 @@ def modified_beam_search( 
index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -170,10 +172,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 0139863a1..d76a03946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -268,7 +271,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -288,7 +293,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -344,7 +351,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -416,7 +425,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -451,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -524,7 +536,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 623bdd51a..1947834bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,7 +96,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -208,7 +210,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need to " + "be changed.", ) parser.add_argument( @@ -231,45 +234,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -634,7 +634,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -647,9 +649,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -660,7 +667,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -828,7 +837,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -952,7 +963,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 5e81aef07..1df7f9ee5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,7 +27,10 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures +from lhotse.dataset.input_strategies import ( + OnTheFlyFeatures, + PrecomputedFeatures, +) from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -41,69 +44,59 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." 
- ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets).", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -124,22 +117,18 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -153,11 +142,9 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet" - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. 
Used only in dev/test CutSet", ) def train_dataloaders( @@ -180,7 +167,9 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") @@ -189,7 +178,9 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -259,7 +250,9 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 66c8e30ba..5784a78ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,7 +79,11 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -116,11 +120,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -190,7 +192,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -277,7 +280,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -307,7 +312,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -351,11 +359,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -428,7 +446,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -461,7 +481,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,7 +532,9 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -544,7 +567,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index d90497e26..8025d6be1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,7 +120,11 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from 
icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -163,11 +167,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -263,7 +265,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -475,7 +478,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -545,7 +550,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -638,11 +646,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -672,7 +690,12 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} + return { + ( + f"beam_size_{params.beam_size}_" + f"temperature_{params.temperature}" + ): hyps + } def decode_dataset( @@ -756,7 +779,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -789,7 +814,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -913,7 +939,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix 
+= ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -953,7 +981,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1003,10 +1032,15 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": + if ( + params.decoding_method + == "fast_beam_search_with_nbest_rnn_rescoring" + ): rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1031,7 +1065,9 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index dcf65e937..47217ba05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,7 +128,11 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -160,11 +164,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -233,7 +235,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -506,9 +509,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -609,7 +616,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -707,7 +715,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 598434f54..36f32c6b3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,14 +52,18 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] + idx_filenames = [ + (int(pattern.search(f).group(1)), f) for f in filenames + ] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + return lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 108915389..162f8c7db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,12 +104,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -144,9 +142,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -331,7 +330,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..7852f84e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 163d737e3..d03d1d7ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,7 +203,9 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -212,10 +214,16 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -224,7 +232,11 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) @torch.no_grad() @@ -276,7 +288,9 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py 
b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 11597aa49..ea5d4e674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,12 +102,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -142,9 +140,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -192,7 +191,11 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, )[0] blank_id = 0 # hard-code to 0 @@ -379,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 849d6cf4e..19b636a23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -234,9 +231,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -304,7 +302,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,7 +391,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 85d87f8f2..1e6022b57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,7 +234,9 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + scaled_weight = ( + scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + ) lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -249,10 +251,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 41a712498..10bb44e00 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,7 +52,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -91,11 +95,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -161,7 +163,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -269,7 +272,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -289,7 +294,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -345,7 +352,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -417,7 +426,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -450,7 +461,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -523,7 +535,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 598fcf344..66ffbd3ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,7 +90,9 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) os.remove(filename) @@ -145,7 +147,9 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -193,7 +197,9 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) + jit_model = torch.jit.trace( + encoder_layer, (x, pos_emb, src_key_padding_mask) + ) torch.onnx.export( encoder_layer, 
@@ -230,7 +236,9 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -314,7 +322,9 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -369,7 +379,9 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 6724343dd..44e96644a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,7 +92,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -211,7 +214,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -234,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -671,7 +672,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -684,9 +687,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -697,7 +705,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -909,7 +919,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -955,7 +967,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -1096,7 +1109,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 69cfcd298..4f043e5a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,24 +197,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -310,7 +306,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -430,7 +427,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if ( params.decoding_method == "fast_beam_search" @@ -486,7 +485,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -564,7 +566,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. 
Args: @@ -639,7 +643,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -648,7 +654,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -686,7 +694,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -713,7 +722,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -762,7 +773,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -799,12 +812,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -827,12 +841,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -860,7 +875,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -887,7 +902,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index bd5801a78..ce7518ceb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -186,12 +183,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -214,12 +212,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -247,7 +246,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -283,7 +282,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index a28e52c78..7af9ea9b8 
100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 76785a845..cf32e565b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -237,45 +239,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -622,7 +621,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -662,7 +665,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -675,9 +680,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -688,7 +698,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -867,7 +879,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -999,7 +1013,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8499651d7..427b06294 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -793,7 +802,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -809,7 +820,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -835,7 +848,9 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, left_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. 
Args: @@ -1103,9 +1118,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1174,25 +1189,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,15 +1253,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1279,17 +1310,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * 
num_heads, tgt_len, src_len ) @@ -1301,9 +1336,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1442,12 +1481,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: @@ -1623,7 +1666,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -1720,14 +1765,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( @@ -1757,8 +1804,7 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," - f" stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index f462cc42f..22bcdd88e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,24 +179,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. 
If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -307,7 +303,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -480,7 +477,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -546,7 +545,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -694,7 +696,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -727,7 +731,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -782,7 +787,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -821,12 +828,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -849,12 +857,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < 
params.avg + 1: raise ValueError( @@ -882,7 +891,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -928,7 +937,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index a739c17bc..b2e5b430e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -281,7 +280,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index e2da0da4c..1e100fcbd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 59a0e8fa2..6fee9483e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 75696d61b..179d9372e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -246,7 +248,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -269,45 +272,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -686,7 +690,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -699,9 +705,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -712,7 +723,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
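The warmup expression above ramps the pruned loss in gradually so that it cannot overwhelm the simple loss before the alignment has been learned. The same schedule written out as a small helper (a restatement of the expression shown, not new behaviour):

def pruned_loss_scale(warmup: float) -> float:
    # 0.0 before warmup completes, 0.1 while partially warmed up, 1.0 afterwards
    if warmup < 1.0:
        return 0.0
    return 0.1 if 1.0 < warmup < 2.0 else 1.0

assert pruned_loss_scale(0.5) == 0.0
assert pruned_loss_scale(1.5) == 0.1
assert pruned_loss_scale(3.0) == 1.0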
- info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -895,7 +908,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1008,7 +1023,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1030,7 +1045,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 40ad61fd4..53788b3f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,7 +90,10 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers + assert ( + middle_output_layer >= 0 + and middle_output_layer < num_encoder_layers + ) output_layers.append(middle_output_layer) # The last layer is always needed. 
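The hunk above lets the conformer return the output of a chosen middle layer alongside the final one; in this recipe the middle output feeds the codebook-index distillation loss while decoding keeps using layer_results[-1]. A toy sketch of the collection pattern with stand-in layers:

import torch
import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))  # stand-ins for encoder layers
output_layers = [3, 5]                                     # a middle layer plus the last layer
x = torch.randn(2, 8)
layer_results = []
for i, layer in enumerate(layers):
    x = layer(x)
    if i in output_layers:
        layer_results.append(x)
# layer_results[-1] is the usual encoder output; the earlier entry feeds the distillation loss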
@@ -175,7 +178,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -357,7 +362,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -372,7 +379,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -647,9 +656,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -718,25 +727,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -773,7 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -781,9 +800,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -817,9 +840,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -842,7 +869,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 600aa9b39..74df04006 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,24 +120,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -212,7 +208,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -270,7 +267,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + layer_results, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) encoder_out = layer_results[-1] hyps = [] @@ -286,7 +285,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -332,7 +334,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -405,7 +411,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -438,7 +446,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -481,7 +490,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -513,12 +524,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -541,12 +553,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -574,7 +587,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 86cf34877..21409287c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,10 +21,9 @@ import os from pathlib import Path import torch +from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from vq_utils import CodebookIndexExtractor - from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index b8440f90a..49b557814 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch + from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -98,7 +99,9 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) 
return results @@ -121,7 +124,9 @@ def save_results( ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -150,7 +155,9 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + params.res_dir = ( + params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + ) setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -183,7 +190,9 @@ def main(): params=params, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 4f9417c9f..55ce7b00d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,7 +22,11 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import checkpoint_utils, tasks, utils +from fairseq import ( + checkpoint_utils, + tasks, + utils, +) from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -47,7 +51,9 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") + model_path = Path(params.hubert_model_dir) / ( + params.teacher_model_id + ".pt" + ) task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -145,7 +151,9 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( + [-1, 1] + ) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -155,7 +163,9 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) + padding_mask = self.w2v_model.forward_padding_mask( + features, padding_mask + ) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -202,7 +212,9 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] + hyps = [ + self.processor.string(tok[tok != blank].int().cpu()) for tok in toks + ] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..7716d19cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,7 +69,9 @@ class Transducer(nn.Module): self.decoder = decoder 
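The hubert_xlarge.py hunks above decode the HuBERT teacher with plain CTC greedy search: frame-wise argmax, collapse repeated tokens, drop blanks, then post-process the letter sequence. A minimal sketch of those steps on toy logits:

import torch

blank = 0
log_probs = torch.randn(2, 7, 32)              # (batch, frames, vocab), toy values
toks = log_probs.argmax(dim=-1)
hyps = []
for tok in toks:
    tok = tok.unique_consecutive()             # collapse repeated frames
    hyps.append(tok[tok != blank].tolist())    # remove blank symbols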
self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -178,7 +180,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -233,7 +237,9 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): + def concat_successive_codebook_indexes( + middle_layer_output, codebook_indexes + ): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index be54ff0ce..f717d85fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -201,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -570,7 +569,9 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. 
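The -100 padding value mentioned in that comment is not arbitrary: it matches the default ignore_index of PyTorch's cross-entropy loss, so padded codebook-index frames contribute nothing to the distillation loss. A small check:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                    # (frames, codebook entries), toy shapes
targets = torch.tensor([3, 7, -100, -100])     # the last two frames are padding
loss = F.cross_entropy(logits, targets)        # only the first two frames contribute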
- cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = [ + c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts + ] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -603,7 +604,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,7 +655,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -663,9 +670,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -678,7 +690,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -859,7 +873,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -991,7 +1007,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 40f97f662..47cf2b14b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,9 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -206,7 +208,9 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to(dtype=torch.float) + yield data[start:end, :].to(self.params.device).to( + dtype=torch.float + ) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -223,11 +227,10 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - split_cmd = ( - "lhotse split" - f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" + ori_manifest = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" ) + split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -237,13 +240,16 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -263,7 +269,8 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -323,7 +330,9 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest_path = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index fa8144935..06c5863f1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,24 +164,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -276,7 +272,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -396,7 +393,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -455,7 +454,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -586,7 +588,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -619,7 +623,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -674,7 +679,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -711,12 +718,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -739,12 +747,13 @@ def main(): 
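Most of the surrounding hunks only re-wrap the checkpoint-selection and averaging calls; conceptually, averaging checkpoints is an element-wise mean of the saved weights. A toy, in-memory sketch with stand-in modules instead of real epoch-*.pt files:

import torch
import torch.nn as nn

models = [nn.Linear(4, 4) for _ in range(3)]   # stand-ins for consecutive epoch checkpoints
state_dicts = [m.state_dict() for m in models]
avg = {
    k: torch.stack([sd[k] for sd in state_dicts]).mean(dim=0)
    for k in state_dicts[0]
}
models[0].load_state_dict(avg)                 # decode with the averaged weights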
model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -772,7 +781,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -799,7 +808,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 5f90e6375..712dc8ce1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim // 4, # group size == 4 + groups=decoder_dim//4, # group size == 4 bias=False, ) @@ -91,7 +91,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 43ac658e5..5744ea3ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
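The decoder.py hunk above is the stateless prediction network: an embedding followed by a causal grouped Conv1d over the last context_size tokens (context_size=2 gives the "tri-gram"-like behaviour the help strings describe), left-padded during training and unpadded at inference. A runnable sketch with assumed sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, decoder_dim, context_size = 500, 512, 2        # assumed sizes
embedding = nn.Embedding(vocab_size, decoder_dim)
conv = nn.Conv1d(
    decoder_dim,
    decoder_dim,
    kernel_size=context_size,
    groups=decoder_dim // 4,                               # group size == 4, as above
    bias=False,
)
y = torch.randint(0, vocab_size, (8, 10))                  # (N, U) token ids
out = embedding(y).permute(0, 2, 1)                        # (N, decoder_dim, U)
out = F.pad(out, pad=(context_size - 1, 0))                # causal: pad on the left only
out = conv(out).permute(0, 2, 1)                           # back to (N, U, decoder_dim)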
", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -218,12 +215,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -246,12 +244,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -279,7 +278,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -317,7 +316,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index c94a34d58..e2405d5ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 3ddac2cf2..7d8de5afe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..53cde6c6f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,15 +15,14 @@ # limitations under the License. -import random - import k2 import torch import torch.nn as nn +import random from encoder_interface import EncoderInterface -from scaling import penalize_abs_values_gt from icefall.utils import add_sos +from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -66,8 +65,7 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, + encoder_dim, vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -135,16 +133,18 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 460ac2c3e..bb8b0a0e3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib -import logging -import random from collections import defaultdict -from typing import List, Optional, Tuple, Union - -import torch +from typing import List, Optional, Union, Tuple, List from lhotse.utils import fix_random_seed +import torch from scaling import ActivationBalancer +import random from torch import Tensor from torch.optim import Optimizer +import logging +import contextlib + class BatchedOptimizer(Optimizer): @@ -37,10 +37,11 @@ class BatchedOptimizer(Optimizer): Args: params: """ - def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager def batched_params(self, param_group): """ @@ -72,9 +73,7 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -83,7 +82,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [batches[key] for key in sorted(batches.keys())] + batches = [ batches[key] for key in sorted(batches.keys()) ] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -95,78 +94,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) + grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) + class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. 
By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. 
+ clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, ): + defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -185,6 +183,7 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -207,9 +206,7 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized + if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -228,9 +225,13 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) + return loss - def _init_state(self, group: dict, p: Tensor, state: dict): + def _init_state(self, + group: dict, + p: Tensor, + state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -246,7 +247,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {"device": p.device, "dtype": p.dtype} + kwargs = {'device':p.device, 'dtype':p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -254,30 +255,36 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() + if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) + state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, + **kwargs) + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) - def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict]] - ) -> float: + def _get_clipping_scale(self, + group: dict, + pairs: List[Tuple[Tensor, dict]]) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. 
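As a minimal sketch of the batched state set up in `_init_state` above (made-up shapes, not part of the patch): dim 0 of a stacked parameter is the stacking dimension, so `param_rms` keeps that dimension and collapses the real ones, giving one rms value per stacked tensor.

import torch

# e.g. 16 identically shaped (384, 512) parameters stacked by batched_params()
p = torch.randn(16, 384, 512)

# shape (batch_size, 1, 1): one root-mean-square value per member of the batch
param_rms = (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
print(param_rms.shape)  # torch.Size([16, 1, 1])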
@@ -307,67 +314,57 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + tot_sumsq += ((grad * state["param_rms"])**2).sum() tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) + if not "model_norms" in first_state: + first_state["model_norms"] = torch.zeros(clipping_update_period, + device=p.device) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + sorted_norms = first_state["model_norms"].sort()[0].to('cpu') quartiles = [] for n in range(0, 5): - index = min( - clipping_update_period - 1, - (clipping_update_period // 4) * n, - ) + index = min(clipping_update_period - 1, + (clipping_update_period // 4) * n) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) + percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state else 0.0) first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) + quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) + logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) + except: + logging.info("Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?") return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}," - f" model_norm_threshold={model_norm_threshold}" - ) + logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") return ans - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): + + def _step_one_batch(self, + group: dict, + p: Tensor, + state: dict, + clipping_scale: float): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. 
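The clipping rule in the hunk above can be summarized in a small sketch (hypothetical inputs; the real code also maintains the rolling buffer of norms and the percent-clipped statistics): take the median of the recent total gradient norms, multiply by `clipping_scale` to get a threshold, and return a factor <= 1.0 that scales an over-threshold gradient back down to it.

import torch

def clipping_factor(tot_norm: torch.Tensor,
                    recent_norms: torch.Tensor,
                    clipping_scale: float = 2.0) -> float:
    # threshold = clipping_scale * median of recent norms; the factor multiplies the grads
    threshold = clipping_scale * recent_norms.median()
    return min(1.0, (threshold / (tot_norm + 1.0e-20)).item())

recent = torch.tensor([0.9, 0.95, 1.0, 1.0, 1.1])
print(clipping_factor(torch.tensor(1.0), recent))  # 1.0, below the threshold of 2.0
print(clipping_factor(torch.tensor(4.0), recent))  # 0.5, clipped down to 2 * median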
@@ -394,18 +391,17 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) + dim=list(range(1, p.ndim)), keepdim=True) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt()) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) + if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -415,21 +411,24 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + def _size_update(self, + group: dict, + scale_grads: Tensor, + p: Tensor, + state: dict) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -444,28 +443,25 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2**size_update_period + beta2_corr = beta2 ** size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) + (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` + alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step + bias_correction2 = 1 - beta2_corr ** size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) + scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom - is_too_small = param_rms < param_min_rms - is_too_large = param_rms > param_max_rms + is_too_small = (param_rms < param_min_rms) + is_too_large = (param_rms > param_max_rms) # when the param gets too small, just don't shrink it any further. 
scale_step.masked_fill_(is_too_small, 0.0) @@ -473,9 +469,13 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) + delta.add_(p * scale_step, alpha=(1-beta1)) - def _step(self, group: dict, p: Tensor, state: dict): + + def _step(self, + group: dict, + p: Tensor, + state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,7 +496,8 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=(1-beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -508,13 +509,17 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - def _step_scalar(self, group: dict, p: Tensor, state: dict): + + def _step_scalar(self, + group: dict, + p: Tensor, + state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -526,7 +531,8 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=1-beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. 
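A condensed view of the non-scalar update in `_step` above, written as a standalone sketch (it ignores the batched shapes, the conditional bias correction, and the momentum buffer `delta`, so it is illustrative rather than exact): the gradient is normalized Adam-style and then additionally scaled by the parameter's rms, which is what makes the step size proportional to the parameter's magnitude.

import torch

def scaled_update(p: torch.Tensor, grad: torch.Tensor, exp_avg_sq: torch.Tensor,
                  step: int, lr: float = 0.03, beta2: float = 0.98,
                  eps: float = 1.0e-08, param_min_rms: float = 1.0e-05) -> torch.Tensor:
    # second-moment estimate and bias correction, as in Adam
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction2 = 1 - beta2 ** (step + 1)
    denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    # the extra ScaledAdam ingredient: scale the step by the parameter's rms
    param_rms = (p ** 2).mean().sqrt().clamp(min=param_min_rms)
    return -lr * param_rms * grad / denom  # the (1 - beta1) momentum factor is omitted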
@@ -534,11 +540,12 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + delta.add_(grad / denom, alpha=-lr*(1-beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -548,14 +555,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["base_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -669,15 +680,13 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) + warmup_factor = (1.0 if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -736,14 +745,13 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam: A Method for Stochastic Optimization: + .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. 
_On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__( self, params, @@ -758,11 +766,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -798,7 +812,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -825,7 +841,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -836,31 +852,30 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) + step = (exp_avg/denom) * step_size + logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") + return loss def _test_scaled_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear - E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") + #device = torch.device('cuda') + device = torch.device('cpu') dtype = torch.float32 fix_random_seed(42) @@ -874,93 +889,79 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential(Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] + train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: optim = 
Eve(m.parameters(), lr=0.003) + elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: + #if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - # if epoch == 130: + #if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - for n, (x, y) in enumerate(train_pairs): + + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" - f" {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - # diagnostic.print_diagnostics() + #diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) + #logging.info("state dict = ", scheduler.state_dict()) + #logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") + if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) + s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) logging.info(s) import sys - if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 8b4d88871..7fe1e681a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4040065e1..50cedba56 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections -import logging -import random -from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union +from functools import reduce +import logging +import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,24 +32,27 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) 
if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -62,22 +65,14 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_scale_factor(x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -88,76 +83,71 @@ def _compute_scale_factor( else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) return below_threshold - above_threshold - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_sign_factor(x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), + dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) + factor1 = ((min_positive - proportion_positive) * + (gain_factor / min_positive)).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) + factor2 = ((proportion_positive - max_positive) * + (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor + class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. 
""" - @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -165,24 +155,18 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -195,32 +179,30 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors + is_same, = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): +def random_clamp(x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = x_abs < min_abs + is_too_small = (x_abs < min_abs) # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -233,7 +215,6 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). 
""" - @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -242,37 +223,35 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) + return random_cast_to_half(ans_grad.to(torch.float32), + min_abs=ctx.min_abs), None else: return ans_grad, None - class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - - def __init__(self, min_abs: float = 5.0e-06): + def __init__(self, + min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, x: Tensor): + def forward(self, + x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) + class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ - @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -288,7 +267,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors + ans, = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -297,7 +276,9 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None -def softmax(x: Tensor, dim: int): + +def softmax(x: Tensor, + dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -307,18 +288,20 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) return x + @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -328,20 +311,15 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x**2).mean() + x_var = (x ** 2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() + x_residual_var = (x_residual ** 2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -407,12 +385,15 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. 
eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -431,11 +412,16 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -454,10 +440,13 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -497,19 +486,18 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -527,7 +515,9 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. 
self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -545,35 +535,26 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) + sign_factor = _compute_sign_factor(x, self.channel_dim, + self.min_positive, self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor) else: sign_factor = None - scale_factor = _compute_scale_factor( - x, - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) + + scale_factor = _compute_scale_factor(x, self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor) return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, + x, scale_factor, sign_factor, self.channel_dim, ) else: return _no_op(x) @@ -613,12 +594,13 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] + x = x[:, ::dim+1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, num_groups: int): +def _whitening_metric(x: Tensor, + num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -648,21 +630,19 @@ def _whitening_metric(x: Tensor, num_groups: int): # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. 
- metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float, - ) -> Tensor: + def forward(ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -670,8 +650,9 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -680,29 +661,25 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}," - f" num_channels={x_orig.shape[-1]}," - f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) + logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) + scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float,float]], + grad_scale: float): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -737,7 +714,8 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, x: Tensor) -> Tensor: + def forward(self, + x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -757,21 +735,19 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, "min_prob") and random.random() < 0.25: + if hasattr(self, 'min_prob') and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): + if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: # there would be a change to the grad. 
self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) + return WhiteningPenaltyFunction.apply(x, + self.num_groups, + self.whitening_limit, + self.grad_scale) class WithLoss(torch.autograd.Function): @@ -779,14 +755,11 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x - @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device - ) - - + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device) def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -795,7 +768,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if (torch.jit.is_scripting()): return x else: # a no-op function that will have a node in the autograd graph, @@ -810,7 +783,6 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) - class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -831,14 +803,13 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -854,7 +825,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) + self.register_buffer('max_eig_direction', direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -862,12 +833,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - ): + if (torch.jit.is_scripting() or + self.max_var_per_eig <= 0 or + random.random() > self.cur_prob): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -877,9 +848,7 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) + new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -892,10 +861,7 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}," - f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) + logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") if variance_proportion >= self.max_var_per_eig: # The constraint is active. 
Note, we should quite rarely @@ -903,16 +869,17 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) + return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, + self.channel_dim, self.scale) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - def _set_direction(self, direction: Tensor): + + def _set_direction(self, + direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -922,39 +889,40 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) + logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}") - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() + def _find_direction_coeffs(self, + x: Tensor, + prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. 
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) return cur_direction, coeffs + + class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -982,7 +950,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = y * (1 - s) + s + deriv = (y * (1 - s) + s) # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -991,9 +959,7 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1006,12 +972,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors + d, = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) class DoubleSwish(torch.nn.Module): @@ -1024,6 +990,7 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) + def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1035,9 +1002,11 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale + m = MaxEig(num_channels, + 1, # channel_dim + 0.5, # max_var_per_eig + scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1062,9 +1031,11 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1078,6 +1049,7 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) + def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1105,7 +1077,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1137,8 +1111,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1150,27 +1124,30 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = (1.2 - (-0.043637)) / 255.0 + tol = ((1.2-(-0.043637))/255.0) torch.autograd.gradcheck(m, x, atol=tol) + # for self-test. 
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) + def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() + a.softmax(dim=1)[:,0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() + softmax(b, dim=1)[:,0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 46e775285..8d357b15f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,7 +26,11 @@ from typing import List import torch import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten +from scaling import ( + ActivationBalancer, + BasicNorm, + Whiten, +) class NonScaledNorm(nn.Module): @@ -71,10 +75,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 7f9526104..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,7 +84,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -122,10 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -140,11 +139,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -272,45 +269,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -652,7 +646,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -699,7 +697,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,7 +870,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -888,7 +890,11 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -899,7 +905,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -907,7 +915,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -919,8 +930,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -999,7 +1009,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1017,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1042,7 +1054,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1216,8 +1229,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index fcd9858cd..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,35 +16,32 @@ # limitations under the License. 
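# The duration filter in the train.py hunk above keeps a cut only when
# 1.0 s <= duration <= 20.0 s; the same rule as a stand-alone predicate
# (the function name here is illustrative, not from the recipe):
def _keep_cut_by_duration(duration_seconds: float) -> bool:
    return 1.0 <= duration_seconds <= 20.0

assert _keep_cut_by_duration(5.3)
assert not _keep_cut_by_duration(0.4) and not _keep_cut_by_duration(25.0)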
import copy -import itertools -import logging import math -import random import warnings +import itertools from typing import List, Optional, Tuple, Union - +import logging import torch +import random from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) from scaling import ( ActivationBalancer, BasicNorm, - DoubleSwish, - Identity, MaxEig, + DoubleSwish, ScaledConv1d, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. Whiten, + Identity, _diag, - penalize_abs_values_gt, random_clamp, + penalize_abs_values_gt, softmax, ) from torch import Tensor, nn -from icefall.dist import get_rank from icefall.utils import make_pad_mask +from icefall.dist import get_rank class Zipformer(EncoderInterface): @@ -92,7 +89,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u, d in zip(encoder_unmasked_dims, encoder_dims): + for u,d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -100,9 +97,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) + self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], + dropout=dropout) + # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -126,13 +123,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -142,11 +139,10 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample( - encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor, - ) + self.downsample_output = AttentionDownsample(encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor) + def _get_layer_skip_dropout_prob(self): if not self.training: @@ -170,33 +166,27 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: + if i <= 1 or z[i-1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i - 2, -1, -1): + for j in range(i-2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has" - f" downsampling_factor={z[i]}, we will combine the outputs" - f" of layers {j} and {i-1}, with" - f" downsampling_factors={z[j]} and {z[i-1]}." 
- ) + logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) + skip_modules.append(SimpleCombiner(self.encoder_dims[j], + self.encoder_dims[i-1], + min_weight=(0.0,0.25))) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks(self, x: torch.Tensor) -> List[float]: + def get_feature_masks( + self, + x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -216,56 +206,46 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders + return [ 1.0 ] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) + + assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = num_frames0 + max_downsampling_factor - 1 + num_frames_max = (num_frames0 + max_downsampling_factor - 1) + feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) + frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds + upsample_factor = (max_downsampling_factor // ds) - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) + frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, + batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1)) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) + feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], + dtype=x.dtype, device=x.device) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks + def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, + self, x: torch.Tensor, x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -285,19 +265,13 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), ( - x.shape, - lengths, - lengths.max(), - ) + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): + for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ 
-306,11 +280,9 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - ) + x = module(x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[...,::ds]) outputs.append(x) x = self.downsample_output(x) @@ -340,16 +312,15 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -359,24 +330,29 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, + d_model, attention_dim, nhead, pos_dim, dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward1 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward2 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module1 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, + cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -384,18 +360,14 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). 
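# ActivationBalancer below only modifies gradients (it is an identity in the
# forward pass); with min_positive=0.45 and max_positive=0.55 it nudges each
# channel towards having roughly half of its values positive, i.e. towards
# zero median.  Stand-alone sketch of the statistic being constrained
# (random data, d_model assumed to be 384):
import torch

demo_out = torch.randn(100, 384)                    # (frames, d_model)
frac_positive = (demo_out > 0).float().mean(dim=0)  # per-channel fraction of positives
# the balancer adds a small corrective gradient to channels whose fraction
# drifts outside [0.45, 0.55]; max_abs=6.0 similarly bounds |activation|.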
self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, + d_model, channel_dim=-1, + min_positive=0.45, max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) + self.whiten = Whiten(num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -410,9 +382,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) + clamp_min = (initial_clamp_min - + (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -427,9 +398,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) + return (initial_dropout_rate - + (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) def forward( self, @@ -538,14 +508,13 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -559,7 +528,8 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -568,13 +538,15 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + + delta = (1. 
/ num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin + def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -607,14 +579,12 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) + return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -634,13 +604,11 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}," - f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," - f" num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) + logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") return ans + def forward( self, src: Tensor, @@ -671,6 +639,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src + if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -701,31 +670,28 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - - def __init__( - self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int, - ): + def __init__(self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) + self.out_combiner = SimpleCombiner(input_dim, + output_dim, + min_weight=(0.0, 0.25)) - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + + def forward(self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. 
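# Shape sketch for the DownsampledZipformerEncoder.forward defined above
# (stand-alone; mean/repeat stand in for the learned AttentionDownsample
# weighted sum and the SimpleUpsample bias):
import torch

seq_len, batch, num_channels, ds = 9, 2, 4, 2
src = torch.randn(seq_len, batch, num_channels)
d_seq_len = (seq_len + ds - 1) // ds              # 5 groups of ds frames
pad = d_seq_len * ds - seq_len                    # right-pad by repeating the last frame
src_pad = torch.cat((src, src[-1:].expand(pad, batch, num_channels)), dim=0)
down = src_pad.reshape(d_seq_len, ds, batch, num_channels).mean(dim=1)
# ... the inner encoder runs on `down` at 1/ds of the frame rate ...
up = (
    down.unsqueeze(1)
    .expand(d_seq_len, ds, batch, num_channels)
    .reshape(d_seq_len * ds, batch, num_channels)[:seq_len]
)
assert up.shape == src.shape  # then combined with the original input by out_combiner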
@@ -752,43 +718,42 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds, ::ds] + mask = mask[::ds,::ds] src = self.encoder( - src, - feature_mask=feature_mask, - mask=mask, - src_key_padding_mask=mask, + src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] + src = src[:src_orig.shape[0]] return self.out_combiner(src_orig, src) - class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): + def __init__(self, + in_channels: int, + out_channels: int, + downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) else: self.extra_proj = None self.downsample = downsample - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -802,14 +767,16 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -828,12 +795,14 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - - def __init__(self, num_channels: int, upsample: int): + def __init__(self, + num_channels: int, + upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -846,7 +815,6 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src - class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -854,7 +822,6 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 - class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -864,14 +831,18 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. 
""" - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + def __init__(self, + dim1: int, + dim2: int, + min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + def forward(self, + src1: Tensor, + src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -882,14 +853,10 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) + if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): + weight1 = weight1.clamp(min=self.min_weight[0], + max=1.0-self.min_weight[1]) + src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -902,9 +869,12 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] + return src1 + src2 + + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -918,7 +888,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -933,7 +905,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -981,6 +955,7 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -1017,46 +992,34 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) + assert ( + self.head_dim * num_heads == attention_dim + ), (self.head_dim, num_heads, attention_dim) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + in_proj_dim = (2 * attention_dim + # query, key + attention_dim // 2 + # value + pos_dim * num_heads) # positional encoding query - self.in_proj = ScaledLinear( - embed_dim, - in_proj_dim, - bias=True, - initial_scale=self.head_dim**-0.25, - ) + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=self.head_dim**-0.25) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. 
- self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) + self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, + initial_scale=0.05) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1068,16 +1031,14 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) + self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, + initial_scale=0.05) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values2 = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + def forward( self, @@ -1137,6 +1098,7 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights + def multi_head_attention_forward( self, x_proj: Tensor, @@ -1194,24 +1156,26 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert head_dim * num_heads == attention_dim, ( - f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," - f" {attention_dim}" - ) + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] + q = x_proj[...,0:attention_dim] + k = x_proj[...,attention_dim:2*attention_dim] value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] + p = x_proj[...,2*attention_dim+value_dim:] + k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. + if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1231,25 +1195,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." 
+ ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1258,6 +1230,7 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1266,10 +1239,13 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1280,16 +1256,13 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), + (pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2)-pos_weights.stride(3), + pos_weights.stride(3)), + storage_offset=pos_weights.stride(3) * (seq_len - 1)) + # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1302,9 +1275,10 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. 
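# penalize_abs_values_gt() (imported from scaling.py) leaves the scores
# unchanged in the forward pass and only injects a gradient pushing |score|
# back under `limit`.  A rough stand-alone illustration of that kind of soft
# penalty (not the icefall implementation):
import torch

scores = (torch.randn(4, 4) * 30.0).requires_grad_()
limit, penalty = 25.0, 1.0e-04
excess = (scores.abs() - limit).clamp(min=0.0)   # nonzero only where |score| > limit
(penalty * excess.sum()).backward()
# scores.grad is +/-penalty where |score| > limit and exactly zero elsewhere,
# which is the gentle clamping effect described in the comment above.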
- attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) + attn_output_weights = penalize_abs_values_gt(attn_output_weights, + limit=25.0, + penalty=1.0e-04) + # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1346,20 +1320,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [ - bsz * num_heads, - seq_len, - head_dim // 2, - ] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, + head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) return attn_output, attn_output_weights + def forward2( self, x: Tensor, @@ -1398,7 +1372,11 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + + def _print_attn_stats( + self, + attn_weights: Tensor, + attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1409,50 +1387,39 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + #eigs, _ = torch.symeig(attn_covar) + #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}," - f" covar={attn_covar}, in_proj_covar={in_proj_covar}," - f" out_proj_covar={out_proj_covar}" - ) + in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) + out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, 
in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") + + class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - - def __init__(self, d_model: int): + def __init__(self, + d_model: int): super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + self.proj = ScaledLinear(d_model, d_model, + initial_scale=0.1, bias=False) - def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): + def forward(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1463,7 +1430,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) + pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1477,19 +1444,24 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + """Feedforward module in Zipformer model. + """ + def __init__(self, + d_model: int, + feedforward_dim: int, + dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) + self.balancer = ActivationBalancer(feedforward_dim, + channel_dim=-1, max_abs=10.0, + min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, + initial_scale=0.01) - def forward(self, x: Tensor): + def forward(self, + x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1509,7 +1481,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1539,10 +1513,7 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) self.depthwise_conv = nn.Conv1d( @@ -1556,10 +1527,8 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, + channels, channel_dim=1, + min_positive=0.05, max_positive=1.0, max_abs=20.0, ) @@ -1575,10 +1544,9 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. 
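# The PoolingModule.forward hunk above computes a padding-aware time average;
# the same masked mean written stand-alone (shapes follow the (T, N, C)
# convention used in this file, data is random):
import torch

T, N, C = 6, 2, 3
x = torch.randn(T, N, C)
key_padding_mask = torch.tensor([[False] * 6, [False] * 4 + [True] * 2])  # (N, T), True = padding
w = key_padding_mask.logical_not().to(x.dtype)    # (N, T)
w = w / w.sum(dim=1, keepdim=True)                # weights sum to 1 per utterance
w = w.transpose(0, 1).contiguous().unsqueeze(-1)  # (T, N, 1)
pooled = (x * w).sum(dim=0, keepdim=True)         # (1, N, C)
assert torch.allclose(pooled[0, 1], x[:4, 1].mean(dim=0))  # padded frames excluded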
@@ -1658,7 +1626,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, channel_dim=1), + ActivationBalancer(layer1_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1667,21 +1636,24 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, channel_dim=1), + ActivationBalancer(layer2_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, channel_dim=1), + ActivationBalancer(layer3_channels, + channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1702,7 +1674,6 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x - class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1746,12 +1717,15 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, + num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob + + def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1782,35 +1756,28 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint( - low=1, - high=int(num_inputs / self.random_prob), - size=(num_frames,), - device=scores.device, - ).unsqueeze(1) + mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), + size=(num_frames,), device=scores.device).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. 
] - arange = ( - torch.arange(num_inputs, device=scores.device) - .unsqueeze(0) - .expand(num_frames, num_inputs) - ) + arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( + num_frames, num_inputs) mask = arange >= mask_start - apply_single_prob = torch.logical_and( - torch.rand(size=(num_frames, 1), device=scores.device) - < self.single_prob, - mask_start < num_inputs, - ) - single_prob_mask = torch.logical_and( - apply_single_prob, arange < mask_start - 1 - ) + apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), + device=scores.device) < self.single_prob, + mask_start < num_inputs) + single_prob_mask = torch.logical_and(apply_single_prob, + arange < mask_start - 1) - mask = torch.logical_or(mask, single_prob_mask) + mask = torch.logical_or(mask, + single_prob_mask) - scores = scores.masked_fill(mask, float("-inf")) + scores = scores.masked_fill(mask, float('-inf')) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1825,6 +1792,7 @@ class AttentionCombine(nn.Module): return ans + def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1833,8 +1801,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0, - ) + single_prob=0.0) + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1851,10 +1819,7 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), + num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) ) batch_size = 5 seq_len = 20 @@ -1872,18 +1837,19 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings - def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, + dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 822f8e44b..9d7335e77 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,24 +165,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. 
Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -277,7 +273,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -397,7 +394,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -456,7 +455,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -587,7 +589,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -620,7 +624,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -675,7 +680,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -712,12 +719,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -745,12 +753,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -779,7 +788,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " 
f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -807,7 +816,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 43eb0c1bc..49f469e29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -220,12 +217,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -254,12 +252,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -288,7 +287,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -327,7 +326,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index ed920dc03..e79a3a3aa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..497b89136 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,7 +160,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 716136812..373a48fc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 381a86a67..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,7 +92,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -130,10 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -148,11 +147,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -217,7 +214,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -287,45 +285,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -696,7 +691,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -745,7 +744,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -951,7 +952,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -972,7 +975,11 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -985,8 +992,12 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1000,7 +1011,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1012,8 +1026,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1041,7 +1054,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -1138,7 +1152,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1156,7 +1172,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1191,7 +1207,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -1346,8 +1364,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 53f383c99..01be7090b 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:
-|-- lang_bpe
-|   |-- L.pt
-|   |-- Linv.pt
+streaming_models/  
+|-- lang_bpe  
+|   |-- L.pt  
+|   |-- Linv.pt  
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
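
The README above asks you to verify the md5sum values of the downloaded streaming_models files; a minimal Python equivalent of that check is sketched below (the file paths and expected digests are placeholders, not the real release values).

import hashlib
import os

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Placeholder entries; fill in the digests published alongside the models.
expected = {
    "streaming_models/lang_bpe/bpe.model": "<expected md5>",
    "streaming_models/lang_bpe/words.txt": "<expected md5>",
}
for path, digest in expected.items():
    if not os.path.exists(path):
        print(f"{path}: missing")
    elif md5sum(path) != digest:
        print(f"{path}: md5 mismatch")
    else:
        print(f"{path}: ok")
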
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index 4f7427c1f..ff4c91446 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,26 +309,36 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(0, num_frames - embed_left_context + 1, stride):
+                for cur in range(
+                    0, num_frames - embed_left_context + 1, stride
+                ):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
-                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(
+                        cur_feature, offset
+                    )
+                    cur_embed = cur_embed.permute(
+                        1, 0, 2
+                    )  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
-                            1, 1, -1
-                        )
+                        pos_emb_central = cur_pos_emb[
+                            0, (chunk_size - 1), :
+                        ].view(1, 1, -1)
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
+                        0
+                    )
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
+                        0
+                    )
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -403,7 +413,9 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -419,16 +431,22 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -462,7 +480,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -534,7 +554,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -714,7 +736,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -731,7 +755,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -757,7 +783,9 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, offset: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -785,7 +813,9 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
+                    self.pe[0, self.pe.size(1) // 2].view(
+                        1, 1, self.pe.size(-1)
+                    ),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1020,9 +1050,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1090,25 +1120,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1147,16 +1185,24 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd, offset=offset
+        )  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1190,9 +1236,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
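
For reference, a small self-contained sketch (not part of the patch itself) of the chunk-by-chunk loop the streaming Conformer forward above iterates over; `decoding_window`, `embed_left_context` and `stride` are illustrative values, not the ones hard-coded in conformer.py.

import torch

def split_into_chunks(
    feature: torch.Tensor,      # (batch, num_frames, feat_dim)
    decoding_window: int,       # frames consumed per forward pass
    embed_left_context: int,    # frames needed by the subsampling front-end
    stride: int,                # hop between consecutive chunks
):
    num_frames = feature.size(1)
    chunks = []
    for cur in range(0, num_frames - embed_left_context + 1, stride):
        end = min(cur + decoding_window, num_frames)
        chunks.append(feature[:, cur:end, :])
    return chunks

if __name__ == "__main__":
    feats = torch.randn(1, 200, 80)
    chunks = split_into_chunks(
        feats, decoding_window=71, embed_left_context=7, stride=64
    )
    print([c.shape[1] for c in chunks])  # chunk lengths, last one shorter
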
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index 5a8149aad..a74c51836 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,7 +28,6 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
-
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -63,36 +62,32 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
         "--chunk-size",
         type=int,
         default=8,
-        help=(
-            "Frames of right context"
-            "-1 for whole right context, i.e. non-streaming decoding"
-        ),
+        help="Frames of right context"
+        "-1 for whole right context, i.e. non-streaming decoding",
     )
 
     parser.add_argument(
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,only used during decoding",
+        help="tailing dummy frames padded to the right,"
+        "only used during decoding",
     )
 
     parser.add_argument(
@@ -144,7 +139,8 @@ def get_parser():
         "--avg-models",
         type=str,
         default=None,
-        help="Manually select models to average, seperated by comma;e.g. 60,62,63,72",
+        help="Manually select models to average, seperated by comma;"
+        "e.g. 60,62,63,72",
     )
 
     return parser
@@ -252,9 +248,13 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(
+        memory_key_padding_mask, 0
+    )  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
+    token_ids = [
+        remove_duplicates_and_blank(token_id) for token_id in token_ids
+    ]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,7 +337,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
 
     return results
 
@@ -355,18 +357,15 @@ def save_results(
     test_set_wers = dict()
     if params.avg_models is not None:
         avg_models = params.avg_models.replace(",", "_")
-        result_file_prefix = (
-            f"epoch-avg-{avg_models}-chunksize        "
-            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
-        )
+        result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
+        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     else:
-        result_file_prefix = (
-            f"epoch-{params.epoch}-avg-{params.avg}-chunksize        "
-            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
-        )
+        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
+        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir
+            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -375,7 +374,8 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir
+            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,7 +384,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -472,7 +474,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -503,7 +507,9 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index 553b7d092..e41b7ea78 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,7 +405,9 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -434,7 +436,9 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
+        .sum()
+        .item()
     )
 
     return loss, info
@@ -547,7 +551,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -662,7 +668,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
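
A worked example of the `utt_pad_proportion` statistic accumulated in compute_loss above: per utterance, the fraction of frames that are padding, summed over the batch (the MetricsTracker is assumed to normalise the summed value later).

import torch

num_frames = torch.tensor([90.0, 100.0, 75.0])   # true lengths
padded_len = 100                                  # feature.size(1) after padding
utt_pad_proportion = ((padded_len - num_frames) / padded_len).sum().item()
print(utt_pad_proportion)  # 0.10 + 0.00 + 0.25 = 0.35
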
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index 0c87fdf1b..bc78e4a41 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -284,17 +286,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -355,17 +363,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -638,7 +652,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -840,7 +856,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -861,7 +879,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
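
The two decoder hunks above assemble decoder inputs by prepending SOS, appending EOS, padding with eos_id / -1, and building a square subsequent mask. A minimal self-contained sketch of that preparation follows; add_sos/add_eos are inlined as plain list operations and generate_square_subsequent_mask is a local re-implementation of the usual causal mask, not the repo's function.

import torch
from torch.nn.utils.rnn import pad_sequence

def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    # additive mask: -inf above the diagonal (future positions), 0 elsewhere
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))

def make_decoder_inputs(token_ids, sos_id: int, eos_id: int):
    ys_in = [torch.tensor([sos_id] + y) for y in token_ids]
    ys_out = [torch.tensor(y + [eos_id]) for y in token_ids]
    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
    ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
    tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1])
    return ys_in_pad, ys_out_pad, tgt_mask

ins, outs, mask = make_decoder_inputs([[5, 7, 9], [3, 4]], sos_id=1, eos_id=2)
print(ins, outs, mask.shape, sep="\n")
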
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 63afd6be2..355ccc99a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -77,18 +77,17 @@ class LibriSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. "
+            "Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -100,74 +99,59 @@ class LibriSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -179,18 +163,17 @@ class LibriSpeechAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -204,22 +187,18 @@ class LibriSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -245,16 +224,20 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -269,7 +252,9 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -313,7 +298,9 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -369,7 +356,9 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
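
Many of the flags above use str2bool as their argparse type; a hedged re-implementation is shown below (icefall.utils provides the real helper, whose exact accepted spellings may differ).

import argparse

def str2bool(v) -> bool:
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--full-libri", type=str2bool, default=True)
print(parser.parse_args(["--full-libri", "false"]))  # Namespace(full_libri=False)
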
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 94ba0a4dc..7d0cd0bf3 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -57,19 +57,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -339,7 +336,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +400,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -466,7 +467,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -495,7 +498,9 @@ def main():
             G=G,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 1731e1ebe..5e04c11b4 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
+            [
+                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
+                for _ in range(5)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
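
A hedged sketch of how the lstms/lstm_bnorms ModuleLists built above might be applied in the forward pass; the permutes around BatchNorm1d and the residual connection are assumptions for illustration, not a copy of model.py.

import torch
import torch.nn as nn

lstms = nn.ModuleList(
    [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
)
lstm_bnorms = nn.ModuleList(
    [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
)

x = torch.randn(100, 4, 500)  # (T, N, C), the default layout for nn.LSTM
for lstm, bnorm in zip(lstms, lstm_bnorms):
    y, _ = lstm(x)
    y = bnorm(y.permute(1, 2, 0)).permute(2, 0, 1)  # BatchNorm1d expects (N, C, T)
    x = x + y  # residual connection (assumed)
print(x.shape)
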
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 722e8f003..2baeb6bba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
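
The read_sound_files and feature-padding pattern above recurs in all of the pretrained.py scripts; here is a self-contained sketch of it, with random tensors standing in for real fbank output.

import math
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

def read_sound_files(filenames, expected_sample_rate=16000):
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        ans.append(wave[0])  # first channel only
    return ans

# padding value log(1e-10) ~= -23, i.e. "silence" in log-Mel space
features = [torch.randn(120, 80), torch.randn(95, 80)]  # stand-ins for fbank output
padded = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(padded.shape)  # torch.Size([2, 120, 80])
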
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 071ac792b..6b37d5c23 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,7 +355,9 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
+        .sum()
+        .item()
     )
 
     return loss, info
@@ -468,7 +470,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index b45b6a9d8..11032f31a 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -121,7 +123,9 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -153,9 +157,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
-                    1, 1
-                )
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
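
A simplified, runnable sketch of the transducer greedy-search control flow touched above, with the encoder/decoder/joiner collapsed into a toy callable; the max_sym_per_frame limit mirrors the --max-sym-per-frame option elsewhere in this patch and is an illustrative addition here.

import torch

def greedy_search(logits_fn, T: int, blank_id: int = 0, max_sym_per_frame: int = 3):
    hyp = []
    t = 0
    sym_per_frame = 0
    while t < T:
        y = int(logits_fn(t, hyp).argmax())
        if y == blank_id or sym_per_frame >= max_sym_per_frame:
            t += 1                 # advance to the next encoder frame on blank
            sym_per_frame = 0
        else:
            hyp.append(y)          # emit a symbol and stay on the same frame
            sym_per_frame += 1
    return hyp

toy = lambda t, hyp: torch.randn(10)  # stand-in for joiner log-probs
print(greedy_search(toy, T=5))
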
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index f30332cea..5f233df87 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -71,19 +71,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -231,7 +228,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -246,7 +245,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -317,7 +318,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -350,7 +353,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 4d9f937f5..5a5db30c4 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -67,20 +67,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -241,7 +238,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 7aadfbcd1..1db2df648 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -60,11 +60,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -89,12 +87,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -192,9 +188,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -252,7 +249,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -288,7 +287,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index fe8732301..2a165b0c1 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,8 +117,12 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
-            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_ih = nn.Parameter(
+                torch.empty(4 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(4 * hidden_size, **factory_kwargs)
+            )
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -344,7 +348,9 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
+        first_layer = LayerNormLSTMLayer(
+            input_size=input_size, **factory_kwargs
+        )
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -379,7 +385,9 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
+        output_states = torch.jit.annotate(
+            List[Tuple[torch.Tensor, torch.Tensor]], []
+        )
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -448,8 +456,12 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
-            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_ih = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 74c94cc70..8591e2d8a 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,7 +254,9 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -301,7 +303,9 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -590,7 +594,9 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
+    (
+        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
+    ).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -712,7 +718,9 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
+    states = [
+        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
+    ]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 674ea10a6..1dd65eddb 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,7 +396,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -518,7 +520,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -655,7 +659,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 5342c3e8c..3531a9633 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -122,7 +124,9 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -154,9 +158,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
-                    1, 1
-                )
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 61b9de504..604235e2a 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -71,19 +71,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=77,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=55,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -228,7 +225,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -243,7 +242,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -314,7 +315,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -347,7 +350,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 038d80077..3dc992dd2 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,7 +48,9 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
+            self.encoder_embed = Conv2dSubsampling(
+                num_features, real_hidden_size
+            )
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 57bda63fd..cdb801e79 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,7 +400,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -522,7 +524,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -661,7 +665,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index 65f2c58d8..f143611ea 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,7 +193,9 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 1d79eef9d..ea985f30d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,7 +478,9 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -494,7 +496,9 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -782,7 +786,9 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -881,7 +887,9 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
@@ -951,9 +959,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 89992856d..48769e9d1 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -54,19 +54,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -127,7 +124,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -164,7 +162,9 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
         batch_size = encoder_out.size(0)
 
@@ -204,7 +204,9 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index d279eae85..cde52c9fc 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,7 +209,10 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
+        if (
+            len(self._init_state) == 2
+            and self._init_state[0].size(1) == left_context
+        ):
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -418,7 +421,9 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -434,16 +439,22 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -475,7 +486,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -501,7 +514,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src, _ = self.conv_module(
+            src, src_key_padding_mask=src_key_padding_mask
+        )
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -566,7 +581,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -608,7 +625,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
+        src, conv_cache = self.conv_module(
+            src, states[1], right_context=right_context
+        )
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -760,7 +779,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -777,7 +798,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -803,7 +826,9 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, left_context: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1067,9 +1092,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1138,25 +1163,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1195,10 +1228,14 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1206,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1251,7 +1290,9 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1263,9 +1304,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1373,12 +1418,16 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert not self.training, "Cache should be None in training time"
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
                         ...,
                     ]
                 else:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 314f49154..74bba9cad 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -232,7 +230,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -440,7 +450,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index a182d91e2..fbc2373a9 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,7 +87,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 7c10b4348..8bd0bdea1 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -246,7 +244,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index e1625992d..93cccbd8c 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,9 +60,13 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
+        encoder_out_list = [
+            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
+        ]
 
-        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
+        decoder_out_list = [
+            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
+        ]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index bd7eeff28..b64521801 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index 9af46846a..b00fc34f1 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,13 +140,16 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second)
+                for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
+            word_starting_time_dict[cuts[i].id] = list(
+                zip(words, word_starting_time)
+            )
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -157,7 +160,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index 65b08d425..d1350c8ab 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,7 +29,9 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    c = Conformer(
+        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
+    )
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index bcb883fa5..ae93f3348 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -421,7 +422,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -542,7 +545,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +664,13 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(f"Before removing short and long utterances: {num_in_total}")
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -689,7 +698,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index b3ff153c1..e851dcc32 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index 86ef9e5b6..ac2807241 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -232,7 +230,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -440,7 +450,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index d95eeb1f4..57c1a6094 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -63,20 +63,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -107,7 +104,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -178,7 +176,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 793931e3b..292f77f03 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 68e247f23..ea15c9040 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -409,7 +410,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -530,7 +533,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +652,13 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(f"Before removing short and long utterances: {num_in_total}")
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -677,7 +686,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 22b6ab911..d596e05cb 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -95,19 +95,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -175,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -249,7 +249,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -295,7 +298,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -368,7 +375,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +410,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -441,7 +451,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index fad9a6977..b6b69d932 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -249,7 +247,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index efd257b5d..f297fa2b2 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
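
The hunk above only re-wraps the call that pads variable-length fbank features
before decoding. A minimal standalone sketch of that padding convention (the
frame counts and the 80-bin feature size below are made up for illustration):

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Three fake utterances with different numbers of frames, 80 mel bins each.
    features = [torch.randn(n, 80) for n in (293, 157, 410)]
    feature_lengths = torch.tensor([f.size(0) for f in features])

    # Pad to the longest utterance; log(1e-10) makes padded frames look like
    # near-silence in log-mel space instead of zeros.
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )
    print(features.shape)     # torch.Size([3, 410, 80])
    print(feature_lengths)    # tensor([293, 157, 410])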
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index 1e1188ca6..ef51a7811 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,7 +41,9 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 88987d91c..27912738c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,7 +114,8 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. "
+        "Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -169,7 +170,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -467,7 +469,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -631,7 +635,9 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -778,7 +784,9 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -817,7 +825,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index bed3856e4..af54dbd07 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,7 +135,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 3790045fa..877720e7b 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,7 +54,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 9bea28a41..6cb8b65ae 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,7 +87,9 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(part["recordings"] for part in manifests.values())
+            recordings=combine(
+                part["recordings"] for part in manifests.values()
+            )
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -106,6 +108,8 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 20ff6d7ab..8116e7605 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,7 +103,11 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
+        stop = (
+            min(args.stop, args.num_splits)
+            if args.stop > 0
+            else args.num_splits
+        )
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -125,7 +129,9 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
+            cut_set = load_manifest_lazy(
+                src_dir / f"cuts_{partition}_raw.jsonl.gz"
+            )
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -138,7 +144,9 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 508d4acd8..8c8f1c133 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,7 +55,9 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
+        train_cuts
+        + train_cuts.perturb_speed(0.9)
+        + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -71,7 +73,9 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
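
The reformatted expression above triples the training data via speed
perturbation. A minimal sketch of that step on its own, assuming lhotse is
available; the manifest paths are hypothetical, not the recipe's exact ones:

    from lhotse import CutSet

    cuts = CutSet.from_file("data/manifests/cuts_train_raw.jsonl.gz")

    # 0.9x and 1.1x perturbed copies are lazily concatenated with the
    # original cuts, giving three variants of every utterance.
    cuts_sp = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
    cuts_sp.to_file("data/manifests/cuts_train_sp.jsonl.gz")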
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 83f95d123..f165f6e60 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -70,12 +70,10 @@ class SPGISpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -87,81 +85,67 @@ class SPGISpeechAsrDataModule:
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it "
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it "
+            "with training dataset. ",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--max-duration",
             type=int,
             default=100.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the BucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the BucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=8,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -173,12 +157,10 @@ class SPGISpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
     def train_dataloaders(
@@ -194,20 +176,24 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -222,7 +208,9 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -239,7 +227,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
             )
         else:
@@ -292,7 +282,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -336,7 +328,9 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
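
The data-module hunks above only re-wrap long calls; as a rough sketch of how
those pieces fit together outside the class (paths and parameter values are
illustrative and assume the lhotse version used by these recipes):

    from lhotse import Fbank, FbankConfig, load_manifest
    from lhotse.dataset import CutMix, K2SpeechRecognitionDataset, SpecAugment
    from lhotse.dataset.input_strategies import OnTheFlyFeatures

    cuts_musan = load_manifest("data/fbank/cuts_musan.jsonl.gz")

    # Cut-level transform: mix MUSAN noise into roughly half of the cuts.
    cut_transforms = [
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
    ]

    # Input-level transform: SpecAugment with time warping on the features.
    input_transforms = [SpecAugment(time_warp_factor=80)]

    train = K2SpeechRecognitionDataset(
        cut_transforms=cut_transforms,
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        input_transforms=input_transforms,
    )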
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index 72a7cd1c1..c39bd0530 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,7 +76,11 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -113,11 +117,9 @@ def get_parser():
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -185,7 +187,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -243,7 +246,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -258,7 +263,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -304,7 +312,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -377,7 +389,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -410,7 +424,9 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -422,23 +438,32 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
-    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+    test_set_wers = {
+        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
+    }
+    test_set_cers = {
+        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
+    }
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
+                "{}\t{}\t{}".format(
+                    key, test_set_wers[key], test_set_cers[key]
+                ),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
+        s += "{}\t{}\t{}{}\n".format(
+            key, test_set_wers[key], test_set_cers[key], note
+        )
         note = ""
     logging.info(s)
 
@@ -471,7 +496,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -503,7 +530,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 1f18ae2f3..77faa3c0e 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,7 +50,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
 
 
@@ -63,20 +67,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -118,7 +119,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -194,7 +196,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index cd835a7b4..dda29b3e5 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -153,7 +155,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need to be "
+        "changed.",
     )
 
     parser.add_argument(
@@ -176,45 +179,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -554,16 +554,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -726,7 +733,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 602e50d29..4582609ac 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,7 +84,9 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -110,7 +112,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 1262baf63..2c5b8b8b3 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,7 +87,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
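
The hunk above re-wraps the out-of-vocabulary fallback in
lexicon_to_fst_no_sil. A minimal sketch of that lookup pattern with a toy
token table (the ids are made up):

    token2id = {"<unk>": 1, "h": 2, "i": 3}

    pieces = ["h", "i", "@"]  # "@" is not in the table
    ids = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
    print(ids)  # [2, 3, 1]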
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 02bd6e2cc..49bfb148b 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,12 +74,10 @@ class TAL_CSASRAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
@@ -93,81 +91,66 @@ class TAL_CSASRAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
@@ -181,18 +164,17 @@ class TAL_CSASRAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -206,22 +188,18 @@ class TAL_CSASRAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -244,20 +222,24 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -272,7 +254,9 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -316,7 +300,9 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -374,7 +360,9 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b2aef7e86..b624913f5 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -124,24 +124,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -212,7 +208,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -271,7 +268,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -304,7 +303,10 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -373,7 +375,9 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+                hyp = sp.decode(
+                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
+                )
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -392,11 +396,11 @@ def decode_one_batch(
         return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": (
-                hyps,
-                zh_hyps,
-                en_hyps,
-            )
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): (hyps, zh_hyps, en_hyps)
         }
     else:
         return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
@@ -502,7 +506,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results, zh_results, en_results
 
 
@@ -535,7 +541,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -578,7 +585,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -610,12 +619,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -638,12 +648,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -671,7 +682,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
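The two hunks above only re-wrap the checkpoint-selection calls, but the rule they implement is easy to lose in the wrapping: with --iter N --avg M the script takes the M most recent checkpoint-*.pt files whose iteration does not exceed N (one extra file when --use-averaged-model is set, so the average still spans M checkpoints). The standalone sketch below mirrors that selection under those assumptions; it does not call icefall's find_checkpoints, and the file-name pattern and example arguments are assumptions:

    import re
    from pathlib import Path
    from typing import List

    def pick_checkpoints(exp_dir: str, iteration: int, avg: int) -> List[str]:
        """Newest-first checkpoint files whose iteration is <= `iteration`."""
        pattern = re.compile(r"checkpoint-(\d+)\.pt$")
        found = []
        for p in Path(exp_dir).glob("checkpoint-*.pt"):
            m = pattern.search(p.name)
            if m and int(m.group(1)) <= iteration:
                found.append((int(m.group(1)), str(p)))
        found.sort(reverse=True)  # largest iteration (most recent) first
        return [name for _, name in found[:avg]]

    # Usage mirroring the script: the caller then checks len(filenames)
    # and raises ValueError exactly as in the hunks above.
    filenames = pick_checkpoints("pruned_transducer_stateless5/exp", 468000, 15)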
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 94a4c7a2e..8f900208a 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -92,24 +92,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -143,7 +139,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -179,12 +176,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -207,12 +205,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -240,7 +239,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -278,7 +277,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index 198242129..dbe213b24 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -84,11 +84,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,7 +165,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -265,11 +263,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -365,7 +367,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
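The pretrained.py changes above are pure re-wraps of the audio-loading and feature-padding calls. As a reading aid, here is a minimal sketch of that part of the pipeline with the fbank extractor stubbed out by random tensors standing in for 80-dim log-mel features; only the 16 kHz assertion and the log(1e-10) padding value come from the script, everything else is an assumption:

    import math
    import torch
    import torchaudio
    from torch.nn.utils.rnn import pad_sequence

    def load_first_channels(filenames, expected_sample_rate=16000):
        waves = []
        for f in filenames:
            wave, sample_rate = torchaudio.load(f)
            assert sample_rate == expected_sample_rate, (
                f"expected sample rate: {expected_sample_rate}. "
                f"Given: {sample_rate}"
            )
            waves.append(wave[0])  # keep only the first channel
        return waves

    # Pretend these are log-mel features of two utterances: (num_frames, 80)
    features = [torch.randn(120, 80), torch.randn(95, 80)]
    feature_lengths = torch.tensor([f.size(0) for f in features])

    # Pad with log(1e-10) so padded frames look like near-silence in log space
    # rather than log(1.0)-level energy before entering the encoder.
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )
    print(features.shape, feature_lengths)  # (2, 120, 80) and [120, 95]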
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 676e8c904..ca35eba45 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,7 +86,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -212,7 +214,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -235,45 +238,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -600,7 +600,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -630,15 +634,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -817,7 +828,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -931,7 +944,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 733ebf235..327962a79 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,7 +83,9 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -102,7 +104,9 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 9dbcc9d9e..49544ccb3 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,7 +25,9 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
+    parser.add_argument(
+        "--texts", type=List[str], help="The input transcripts list."
+    )
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index b9160b6d4..35dd332e8 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
+import lhotse
 import argparse
 import logging
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -62,7 +61,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
+    sups = lhotse.load_manifest(
+        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
+    )
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -87,7 +88,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 7ea4e89a4..1039ac5bb 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
+import lhotse
 import argparse
 import logging
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -62,7 +61,9 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
+    sups = lhotse.load_manifest(
+        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
+    )
     for s in sups:
         texts.append(s.text)
 
@@ -82,7 +83,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 6bae33e65..2b294e601 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -94,20 +94,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -175,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index 244740932..a1c3bcea3 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -65,20 +65,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -109,7 +106,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -181,7 +179,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 00545f107..8480ac029 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -93,11 +93,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -124,12 +122,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,7 +165,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -206,9 +203,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -273,7 +271,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,7 +298,10 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -350,7 +353,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 70c5e290f..8d5cdf683 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,45 +133,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -559,7 +556,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -679,7 +678,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index f90f79d8c..94784c4c4 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -18,6 +18,7 @@
 
 import argparse
 import logging
+
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -62,12 +63,10 @@ class TedLiumAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -79,90 +78,74 @@ class TedLiumAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -174,25 +157,23 @@ class TedLiumAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset.",
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
         """
         Args:
@@ -205,7 +186,9 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
 
             input_transforms.append(
                 SpecAugment(
@@ -225,16 +208,20 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -260,7 +247,9 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -317,7 +306,9 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -348,7 +339,9 @@ class TedLiumAsrDataModule:
         logging.debug("About to create test dataset")
         if self.args.on_the_fly_feats:
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -382,9 +375,13 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
+        )
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 1f99edaf3..77caf6460 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,7 +148,9 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -164,7 +166,9 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -340,9 +344,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -379,7 +383,9 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
@@ -448,9 +454,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
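The beam_search.py hunks re-wrap two calls that matter when reading the decoding code: HypothesisList.add() merges hypotheses with the same token sequence by summing their probabilities in log space, and get_most_probable(length_norm=True) compares log-probabilities divided by hypothesis length to avoid favouring short outputs. A two-line sketch of the merge rule with arbitrary values:

    import torch

    a = torch.tensor(-1.2)   # log-prob already stored under this key
    b = torch.tensor(-2.3)   # log-prob of the duplicate hypothesis being added
    merged = torch.logaddexp(a, b)  # log(exp(-1.2) + exp(-2.3)) ~= -0.91
    print(merged >= torch.maximum(a, b))  # tensor(True): merging never lowers the score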
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index 12d0e2652..d3e9e55e7 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -81,19 +81,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -133,7 +130,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -252,7 +250,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,7 +275,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -346,7 +348,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -379,7 +383,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f9a3814c6..f0c6f32b6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,7 +90,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
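The decoder.py hunk above re-wraps the left-padding call. The reason for pad=(self.context_size - 1, 0) is that the stateless decoder runs a grouped Conv1d with kernel_size equal to context_size over the embeddings, and left-padding by context_size - 1 keeps one output per token while letting each output see only the current and previous tokens. A sketch with placeholder dimensions (512 is an assumed embedding size, not taken from this recipe):

    import torch
    import torch.nn.functional as F

    context_size, dim, T = 2, 512, 10
    embedding_out = torch.randn(1, dim, T)                     # (N, dim, T)
    padded = F.pad(embedding_out, pad=(context_size - 1, 0))   # pad on the left only
    conv = torch.nn.Conv1d(dim, dim, kernel_size=context_size, groups=dim)
    print(padded.shape, conv(padded).shape)
    # torch.Size([1, 512, 11]) torch.Size([1, 512, 10])  -> sequence length preserved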
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index 0b2ae970b..c32b1d002 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -249,7 +247,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index 912d65497..c0e3bb844 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -82,11 +82,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -112,12 +110,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -131,7 +127,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -225,9 +222,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -287,7 +285,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 6fed32e81..09cbf4a00 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,7 +133,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -524,7 +525,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -644,7 +647,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index d8ceb82b6..b78c16b88 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
+```
\ No newline at end of file
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 32c248d7e..58cab4cf2 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,7 +146,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index ecdf10ba9..f25786a0c 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,7 +85,9 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -99,7 +101,9 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 0cf0f0deb..04023a9ab 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,7 +62,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    supervisions_train = (
+        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    )
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -95,7 +97,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index d11cd3a05..ae1b96a68 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it
+#	     on 39 phones. About how to get these LM files, you can know it 
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#
+#	
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 5a59a13ce..4f2aa2340 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -57,19 +57,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -339,7 +336,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +400,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,7 +462,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -482,7 +485,9 @@ def main():
         G=G,
     )
 
-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 9a594a969..4d2199ace 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
-from typing import Optional
-
 import torch
 import torch.nn as nn
+
 from torch import Tensor
+from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,7 +261,9 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
+                hx = hx.reshape(
+                    self.num_layers, self.batch_size * 2, self.hidden_size
+                )
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -443,7 +445,9 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
+                    torch.ones(
+                        self.N_drop_masks, self.hidden_size, device=w.device
+                    )
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index da669bc39..7da285944 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 48b7feda0..452c2a7cb 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,7 +449,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index d957c22e1..1554e987f 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -63,12 +63,10 @@ class TimitAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--feature-dir",
@@ -80,91 +78,75 @@ class TimitAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -172,13 +154,15 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.feature_dir / "musan_cuts.jsonl.gz"
+        )
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -194,9 +178,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
-            "num_frame_masks"
-        ]
+        num_frame_masks_parameter = inspect.signature(
+            SpecAugment.__init__
+        ).parameters["num_frame_masks"]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -228,7 +212,9 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -277,7 +263,9 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -311,14 +299,20 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                )
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
+            sampler = SingleCutSampler(
+                cuts_test, max_duration=self.args.max_duration
+            )
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
+            test_dl = DataLoader(
+                test, batch_size=None, sampler=sampler, num_workers=1
+            )
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 319ee5515..5e7300cf2 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -56,19 +56,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=25,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -338,7 +335,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +399,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -460,7 +461,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -480,7 +483,9 @@ def main():
         G=G,
     )
 
-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index e211ad80d..51edb97e2 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,7 +74,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
+            [
+                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
+                for _ in range(4)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 0c72c973b..5f478da1c 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index be1ecffaa..849256b98 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,7 +449,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index bd73e520e..8a9f6ed30 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,7 +20,12 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
+from lhotse import (
+    CutSet,
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    LilcomHdf5Writer,
+)
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -78,7 +83,9 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index c228597b8..a882b6113 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -62,10 +62,8 @@ def get_parser():
         "--batch-duration",
         type=float,
         default=600.0,
-        help=(
-            "The maximum number of audio seconds in a batch."
-            "Determines batch size dynamically."
-        ),
+        help="The maximum number of audio seconds in a batch."
+        "Determines batch size dynamically.",
     )
 
     parser.add_argument(
@@ -154,7 +152,9 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index d8622842f..8bc073c75 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,7 +83,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -136,7 +138,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 93ce750f8..817969c47 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,7 +115,11 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            cut_set = (
+                cut_set
+                + cut_set.perturb_speed(0.9)
+                + cut_set.perturb_speed(1.1)
+            )
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index e121d842c..1c463cf1c 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index da7d7e061..755fbb2d7 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq
+      echo "This script is intended to be used with jq but you have not installed jq 
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index bd92ac115..10c953e3b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class WenetSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class WenetSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class WenetSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -234,20 +212,24 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -262,7 +244,9 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -305,7 +289,9 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -362,7 +348,9 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -426,7 +414,8 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir
+            / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -438,9 +427,13 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
+        )
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
+        )
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index 6e856248c..f0c9bebec 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,7 +114,11 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -133,30 +137,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -253,7 +252,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,7 +328,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -387,7 +389,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -433,7 +438,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -506,7 +515,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -539,7 +550,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -651,7 +663,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
@@ -702,7 +716,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -712,7 +727,8 @@ def main():
     )
 
     test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -723,7 +739,9 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+        for path in sorted(
+            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
+        )
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index c742593df..933642a0f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -126,20 +126,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -208,7 +205,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -470,9 +468,13 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    encoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_encoder_proj.onnx"
+    )
 
-    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+    decoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_decoder_proj.onnx"
+    )
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -643,7 +645,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index ed9020c67..e5cc47bfe 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -107,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -147,9 +145,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -332,7 +331,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index a46ff5a07..c396c50ef 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,7 +219,9 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
+        joiner_encoder_proj_inputs = {
+            encoder_proj_input_name: encoder_out.numpy()
+        }
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -228,10 +230,16 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
+        joiner_decoder_proj_inputs = {
+            decoder_proj_input_name: decoder_out.numpy()
+        }
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -240,7 +248,11 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )
 
 
 @torch.no_grad()
@@ -292,7 +304,9 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index f7d962008..3770fbbb4 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -111,12 +111,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -151,9 +149,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -201,7 +200,11 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
+        {
+            joiner_encoder_proj.get_inputs()[
+                0
+            ].name: packed_encoder_out.data.numpy()
+        },
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -386,7 +389,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 26c9c2b8c..9a549efd9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -80,11 +80,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -109,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -162,7 +158,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +189,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -255,7 +253,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,7 +280,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -332,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index e020c4c05..d3cc7c9c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,7 +115,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -217,45 +219,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -591,15 +590,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -756,7 +762,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -856,7 +864,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index 1023c931a..dd27c17f0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,7 +210,10 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
+        if (
+            len(self._init_state) == 2
+            and self._init_state[0].size(1) == left_context
+        ):
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -430,7 +433,9 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -448,7 +453,9 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
         self.norm_final = BasicNorm(d_model)
 
@@ -513,7 +520,9 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        conv, _ = self.conv_module(
+            src, src_key_padding_mask=src_key_padding_mask
+        )
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -757,7 +766,9 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -773,7 +784,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1060,9 +1073,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1131,25 +1144,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1187,15 +1208,23 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        )  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,17 +1265,21 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
-            else:
-                # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(
                     1
                 ).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(
+                    0
+                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1258,9 +1291,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1393,12 +1430,16 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert not self.training, "Cache should be None in training time"
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 3d66f9dc9..344e31283 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -160,24 +160,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -248,7 +244,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -345,7 +342,9 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
     hyps = []
 
@@ -361,7 +360,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -407,7 +409,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -478,7 +484,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -511,7 +519,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -580,12 +589,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -608,12 +618,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -641,7 +652,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -709,7 +720,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -719,7 +731,8 @@ def main():
     )
 
     test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -730,7 +743,9 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+        for path in sorted(
+            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
+        )
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index e522943c0..386248554 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,7 +75,9 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -89,11 +91,13 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
-                decoding_graph
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
+                k2.RnntDecodingStream(decoding_graph)
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
 
     @property
     def done(self) -> bool:
@@ -122,10 +126,13 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
+        ret_length = min(
+            self.num_frames - self.num_processed_frames, chunk_length
+        )
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
+            self.num_processed_frames : self.num_processed_frames  # noqa
+            + ret_length
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index fb53f70ab..d0a7fd69f 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -90,20 +90,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -134,7 +131,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -203,7 +201,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 9834189d8..1b064c874 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -80,11 +80,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -109,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -161,7 +157,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +189,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -255,7 +253,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,7 +280,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -332,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 810d94135..651aff6c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,10 +173,14 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
+                num_active_paths
+            )
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 31a7fe605..ff96c6487 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -119,24 +119,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -205,7 +201,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -314,7 +311,9 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
+        )
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -334,7 +333,9 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -388,7 +389,9 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -458,7 +461,9 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     return {key: decode_results}
 
@@ -494,7 +499,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -559,12 +565,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -587,12 +594,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -620,7 +628,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 40c9665f7..2052e9da7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,7 +98,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -258,7 +260,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -281,45 +284,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -665,7 +665,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -697,16 +701,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -830,7 +841,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -888,7 +901,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1001,7 +1016,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1169,7 +1184,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index 7234ca929..f83be05cf 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,7 +128,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 75d95df68..9a4e8a36f 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,7 +54,9 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
+    extractor = Fbank(
+        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
+    )
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -69,7 +71,9 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -83,7 +87,9 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 21860d2f5..85e5f1358 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -56,12 +56,10 @@ class YesNoAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--feature-dir",
@@ -73,91 +71,75 @@ class YesNoAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=30.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=10,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -168,7 +150,7 @@ class YesNoAsrDataModule(DataModule):
         transforms = []
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 41afe0404..9d4ab4b61 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -35,19 +35,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=14,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=2,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -204,7 +201,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -275,7 +274,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -296,7 +297,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -314,7 +317,9 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 09a8672ae..14220be19 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -41,11 +41,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -55,18 +53,18 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -103,9 +101,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -160,7 +159,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -200,7 +201,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 335493491..f32a27f35 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,7 +430,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index de478334e..6714180db 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -48,19 +48,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=125,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--exp-dir",
@@ -119,7 +116,9 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -187,7 +186,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -302,7 +303,9 @@ def main():
         model=model,
     )
 
-    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index 88866ae81..deb92107d 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,7 +430,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index c31db6e4c..235160e14 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,7 +71,9 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -94,7 +96,9 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 8aa0a8eeb..5069b78e8 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,11 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
+    iter_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
+    iter_checkpoints = sorted(
+        iter_checkpoints, reverse=True, key=lambda x: x[0]
+    )
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -465,5 +469,7 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            v += (
+                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            )
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index e4c614c4e..099e2d171 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,9 +334,13 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -366,7 +370,9 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
+        one_best = k2.shortest_path(
+            path_lattice, use_double_scores=use_double_scores
+        )
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -436,7 +442,9 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
+        ragged_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.scores.contiguous()
+        )
 
         tot_scores = ragged_scores.sum()
 
@@ -475,7 +483,9 @@ def one_best_decoding(
             am_scores = saved_am_scores / lm_scale
             lattice.scores = am_scores + lattice.lm_scores
 
-            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+            best_path = k2.shortest_path(
+                lattice, use_double_scores=use_double_scores
+            )
             key = f"lm_scale_{lm_scale}"
             ans[key] = best_path
         return ans
@@ -686,7 +696,9 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -793,9 +805,13 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
-            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
+            logging.info(
+                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
+            )
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -807,7 +823,9 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
+            logging.info(
+                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
+            )
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,7 +912,9 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 7b58ffbd4..b075aceac 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x**2
+        x = x ** 2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in ["value", "max", "min"]
+        assert stats_type in [ "value", "max", "min" ]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,9 +121,7 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = (
-            None  # we'll later assign a list to this data member.  It's a list of dict.
-        )
+        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -135,6 +133,7 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
+
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -186,12 +185,17 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                    if (
+                        this_dim_stats[stats_type] != []
+                        and stats_type == "eigs"
+                    ):
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
+                        this_dim_stats[stats_type].append(
+                            TensorAndCount(stats, count)
+                        )
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -207,6 +211,7 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
+
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -216,8 +221,7 @@ class TensorDiagnostic(object):
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / get_count(x.count) for x in stats_list],
-                        dim=0,
+                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                     )
 
                 if stats_type == "eigs":
@@ -225,7 +229,9 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print("Error getting eigenvalues, trying another method.")
+                        print(
+                            "Error getting eigenvalues, trying another method."
+                        )
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -236,9 +242,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
-                    stats.numel()
-                )
+                summarize = (
+                    len(stats_list) > 1
+                ) or self.opts.dim_is_summarized(stats.numel())
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -255,15 +261,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in ["value", "rms", "eigs"]:
+                if stats_type in [ "value", "rms", "eigs" ]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats**2).sum().sqrt().item()
+                    norm = (stats ** 2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats**2).mean().sqrt().item()
+                rms = (stats ** 2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -271,17 +277,17 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
-                )
-                maybe_class_name = (
-                    f" type={self.class_name}," if self.class_name is not None else ""
+                    f"{sizes[0]}"
+                    if len(sizes) == 1
+                    else f"{min(sizes)}..{max(sizes)}"
                 )
+                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str},"
-                    f" {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
+
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -339,32 +345,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+        def forward_hook(
+            _module, _input, _output, _model_diagnostic=ans, _name=name
+        ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(
-                    _output, class_name=type(_module).__name__
-                )
+                _model_diagnostic[f"{_name}.output"].accumulate(_output,
+                                                                class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                        o, class_name=type(_module).__name__
-                    )
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                         class_name=type(_module).__name__)
 
-        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+        def backward_hook(
+            _module, _input, _output, _model_diagnostic=ans, _name=name
+        ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(
-                    _output, class_name=type(_module).__name__
-                )
+                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
+                                                              class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                        o, class_name=type(_module).__name__
-                    )
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                       class_name=type(_module).__name__)
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 9df1c5bd1..7016beafb 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,7 +29,9 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 373e9a9ff..8aeda6be2 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,7 +53,9 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+        git_commit = (
+            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+        )
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index e2ff03f61..570ed7d7a 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 398a5f689..fbcf5e148 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import random
-
 import torch
 from torch import Tensor, nn
+import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -57,7 +56,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
+                        f"The sum of {_name}.grad is not finite" # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -66,20 +65,28 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
+                        logging.warning(
+                            f"The sum of {_name}.grad[{i}] is not finite"
+                        )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
+
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(grad, _name=name):
+        def param_backward_hook(
+                grad, _name=name
+        ):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(f"The sum of {_name}.param_grad is not finite")
+                logging.warning(
+                    f"The sum of {_name}.param_grad is not finite"
+                )
 
         parameter.register_hook(param_backward_hook)
 
 
+
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 22e1b78bb..80bd7c1ee 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,12 +49,18 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(f"Found bad line {line} in lexicon file {filename}")
-                logging.info("Every line is expected to contain at least 2 fields")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info(
+                    "Every line is expected to contain at least 2 fields"
+                )
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -113,7 +119,9 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError("It's assumed that each word has a unique pronunciation")
+        raise RuntimeError(
+            "It's assumed that each word has a unique pronunciation"
+        )
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 16ed6e032..2c479fc2c 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,7 +63,10 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes])
+        .t()
+        .reshape(-1)
+        .to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -112,12 +115,20 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=beam_size
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=beam_size
+    )
 
-    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    num_tot_scores = num_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
-    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    den_tot_scores = den_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -157,9 +168,13 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    num_tot_scores = num_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
-    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    den_tot_scores = den_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 9f680f83d..0d901227d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,7 +137,9 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
+        transcript_fsa_with_self_loops = k2.arc_sort(
+            transcript_fsa_with_self_loops
+        )
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -153,7 +155,9 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
+            indexes = torch.zeros(
+                len(texts), dtype=torch.int32, device=self.device
+            )
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 9a275bf28..550801a8f 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -46,19 +46,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -197,7 +194,7 @@ def main():
 
     logging.info(f"Number of model parameters: {num_param}")
     logging.info(
-        "Number of model parameters (requires_grad): "
+        f"Number of model parameters (requires_grad): "
         f"{num_param_requires_grad} "
         f"({num_param_requires_grad/num_param_requires_grad*100}%)"
     )
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 4bf982503..598e329c4 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,8 +155,12 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
-        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
+        x = sentence_tokens_with_sos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
+        y = sentence_tokens_with_eos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 2e878f5c8..094035fce 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -38,20 +38,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -162,7 +159,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index 9eef88840..a6144727a 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,7 +129,9 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -159,12 +161,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
-                device
-            )
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
-                device
-            )
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -179,8 +181,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
 
         device = next(self.parameters()).device
 
@@ -188,7 +194,9 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index e17b50332..bb5f03fb9 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,13 +446,17 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/tot_ppl", tot_ppl, params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -467,7 +471,8 @@ def train_one_epoch(
 
             valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
             logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
+                f"ppl: {valid_ppl}"
             )
 
             if tb_writer is not None:
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index a3bf1ef4c..c2edd823e 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,50 +15,30 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import argparse
-import io
-import math
+import sys
 import os
 import re
-import sys
+import io
+import math
+import argparse
 from collections import Counter, defaultdict
 
-parser = argparse.ArgumentParser(
-    description="""
+
+parser = argparse.ArgumentParser(description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """
-)
-parser.add_argument(
-    "-ngram-order",
-    type=int,
-    default=4,
-    choices=[2, 3, 4, 5, 6, 7],
-    help="Order of n-gram",
-)
+    """)
+parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument(
-    "-lm",
-    type=str,
-    default=None,
-    help="Path to output arpa file for language models",
-)
-parser.add_argument(
-    "-verbose",
-    type=int,
-    default=0,
-    choices=[0, 1, 2, 3, 4, 5],
-    help="Verbose level",
-)
+parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
+parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
 args = parser.parse_args()
 
-default_encoding = (
-    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-)
-# Need to be very careful about the use of strip() and split()
-# in this case, because there is a latin-1 whitespace character
-# (nbsp) which is part of the unicode encoding range.
-# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+                              # Need to be very careful about the use of strip() and split()
+                              # in this case, because there is a latin-1 whitespace character
+                              # (nbsp) which is part of the unicode encoding range.
+                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -72,9 +52,7 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(
-            set
-        )  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -84,15 +62,10 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return " total={0}: {1}".format(
+        return ' total={0}: {1}'.format(
             str(self.total_count),
-            ", ".join(
-                [
-                    "{0} -> {1}".format(word, count)
-                    for word, count in self.word_to_count.items()
-                ]
-            ),
-        )
+            ', '.join(['{0} -> {1}'.format(word, count)
+                      for word, count in self.word_to_count.items()]))
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -112,7 +85,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
+    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -130,48 +103,39 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(
-            predicted_word, context_word, count
-        )
+        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == "":
+        if line == '':
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order + 1):
+            for n in range(1, self.ngram_order+1):
                 if i + n > len(words):
                     break
-                ngram = words[i : i + n]
+                ngram = words[i: i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[:-1])
+                history = tuple(ngram[: -1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i - 1]
+                    context_word = words[i-1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(
-            sys.stdin.buffer, encoding=default_encoding
-        )  # byte stream as input
+        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print(
-                "make_phone_lm.py: processed {0} lines of input".format(
-                    lines_processed
-                ),
-                file=sys.stderr,
-            )
+            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -181,12 +145,7 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print(
-                "make_phone_lm.py: processed {0} lines of input".format(
-                    lines_processed
-                ),
-                file=sys.stderr,
-            )
+            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -194,11 +153,9 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [
-            0
-        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-        # but perhaps this is not the case for some other scenarios.
+        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+                      # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -208,11 +165,9 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(
-                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
-            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
-            # which could happen if the number of symbols is small.
-            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
+                                                                # which could happen if the number of symbols is small.
+                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -227,9 +182,7 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = (
-                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
-                )
+                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -243,17 +196,11 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = (
-                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
-                        )
+                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
                 else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = (
-                            max((n_star_z - self.d[n]), 0)
-                            * 1.0
-                            / counts_for_hist.total_count
-                        )
+                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -293,18 +240,12 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for (
-                            u
-                        ) in (
-                            a_counts_for_hist.word_to_count.keys()
-                        ):  # Should be careful here: what is Z1
+                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
-                                1.0 - sum_z1_f_z
-                            )
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -318,9 +259,7 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append(
-                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
-                    )
+                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -383,40 +322,27 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append(
-                            "{1}\t{0}\t{2}".format(
-                                ngram, math.log(f, 10), math.log(bow, 10)
-                            )
-                        )
+                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(
-        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
-    ):
+    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
         # print as ARPA format.
 
-        print("\\data\\", file=fout)
+        print('\\data\\', file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print(
-                "ngram {0}={1}".format(
-                    hist_len + 1,
-                    sum(
-                        [
-                            len(counts_for_hist.word_to_f)
-                            for counts_for_hist in self.counts[hist_len].values()
-                        ]
-                    ),
-                ),
-                file=fout,
+            print('ngram {0}={1}'.format(
+                hist_len + 1,
+                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
+                file=fout
             )
 
-        print("", file=fout)
+        print('', file=fout)
 
         for hist_len in range(self.ngram_order):
-            print("\\{0}-grams:".format(hist_len + 1), file=fout)
+            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -428,12 +354,12 @@ class NgramCounts:
                     if prob == 0:  # f(<s>) is always 0
                         prob = 1e-99
 
-                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
+                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
                     if bow is not None:
-                        line += "\t{0}".format("%.7f" % math.log10(bow))
+                        line += '\t{0}'.format('%.7f' % math.log10(bow))
                     print(line, file=fout)
-            print("", file=fout)
-        print("\\end\\", file=fout)
+            print('', file=fout)
+        print('\\end\\', file=fout)
 
 
 if __name__ == "__main__":
@@ -453,5 +379,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, "w", encoding=default_encoding) as f:
+        with open(args.lm, 'w', encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index 785bd80f9..143c79497 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,7 +130,9 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+        formatter = (
+            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+        )
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -201,7 +203,7 @@ def encode_supervisions(
                 supervisions["num_frames"],
                 subsampling_factor,
                 rounding_mode="floor",
-            ),
+            )
         ),
         1,
     ).to(torch.int32)
@@ -286,9 +288,13 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
         )
-        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -357,7 +363,9 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
+    tokens = k2.RaggedTensor(
+        token_shape, getattr(best_paths, kind).contiguous()
+    )
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -578,7 +586,9 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -588,7 +598,9 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -602,7 +614,9 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    print(
+        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
+    )
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -777,7 +791,9 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -787,7 +803,9 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -801,7 +819,9 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    print(
+        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
+    )
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -871,7 +891,9 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
+                float(v) / num_frames
+                if "utt_" not in k
+                else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -905,7 +927,9 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
+def concat(
+    ragged: k2.RaggedTensor, value: int, direction: str
+) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -951,8 +975,8 @@ def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTens
         ans = k2.ragged.cat([ragged, pad], axis=1)
     else:
         raise ValueError(
-            f'Unsupported direction: {direction}. "             "Expect either "left"'
-            ' or "right"'
+            f"Unsupported direction: {direction}. "
+            'Expect either "left" or "right"'
         )
     return ans
 
@@ -1077,7 +1101,9 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
+def measure_weight_norms(
+    model: nn.Module, norm: str = "l2"
+) -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1100,7 +1126,9 @@ def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]
         return norms
 
 
-def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
+def measure_gradient_norms(
+    model: nn.Module, norm: str = "l1"
+) -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1385,7 +1413,9 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
+        time = convert_timestamp(
+            res.timestamps[i], subsampling_factor, frame_shift_ms
+        )
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index 3183055d4..b4f8c3377 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 88
+line-length = 80
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index ccd2503ff..6c720e121 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,7 @@
 #!/usr/bin/env python3
 
-from pathlib import Path
-
 from setuptools import find_packages, setup
+from pathlib import Path
 
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 34e829642..511a11c23 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,7 +20,11 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    load_checkpoint,
+    save_checkpoint,
+)
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 4c2e192a7..97964ac67 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,7 +23,6 @@ You can run this file in one of the two ways:
 """
 
 import k2
-
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index 10443cf22..ccfb57d49 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,7 +154,9 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
+        lattice = k2.intersect(
+            decoding_graph, fsas, treat_epsilons_specially=False
+        )
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 31f06bd51..6a9ce7853 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,7 +50,9 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
+            torch.tensor(
+                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
+            ),
         )
     )
     assert texts == ["two", "one", "three"]

From 107df3b115a58f1b68a6458c3f94a130004be34c Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 09:42:17 -0500
Subject: [PATCH 038/120] apply black on all files

---
 .github/workflows/style_check.yml             |  11 +-
 .pre-commit-config.yaml                       |  28 +-
 docker/README.md                              |  24 +-
 .../Dockerfile                                |  14 +-
 .../Dockerfile                                |  17 +-
 .../images/k2-gt-v1.9-blueviolet.svg          |   2 +-
 .../images/python-gt-v3.6-blue.svg            |   2 +-
 .../images/torch-gt-v1.6.0-green.svg          |   2 +-
 docs/source/recipes/aishell/index.rst         |   1 -
 docs/source/recipes/timit/index.rst           |   1 -
 docs/source/recipes/timit/tdnn_ligru_ctc.rst  |  28 +-
 docs/source/recipes/timit/tdnn_lstm_ctc.rst   |  24 +-
 .../local/compute_fbank_aidatatang_200zh.py   |   8 +-
 .../ASR/local/prepare_char.py                 |   8 +-
 .../ASR/local/prepare_lang.py                 |   4 +-
 .../ASR/local/test_prepare_lang.py            |   4 +-
 egs/aidatatang_200zh/ASR/local/text2token.py  |  15 +-
 egs/aidatatang_200zh/ASR/prepare.sh           |   3 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless2/decode.py    |  25 +-
 .../pruned_transducer_stateless2/export.py    |   7 +-
 .../pretrained.py                             |  19 +-
 .../ASR/pruned_transducer_stateless2/train.py |  29 +-
 egs/aishell/ASR/conformer_ctc/conformer.py    |  67 +-
 egs/aishell/ASR/conformer_ctc/decode.py       |  16 +-
 egs/aishell/ASR/conformer_ctc/export.py       |   4 +-
 egs/aishell/ASR/conformer_ctc/pretrained.py   |  11 +-
 egs/aishell/ASR/conformer_ctc/subsampling.py  |  16 +-
 .../ASR/conformer_ctc/test_subsampling.py     |   3 +-
 egs/aishell/ASR/conformer_ctc/train.py        |  12 +-
 egs/aishell/ASR/conformer_ctc/transformer.py  |  44 +-
 egs/aishell/ASR/conformer_mmi/conformer.py    |  67 +-
 egs/aishell/ASR/conformer_mmi/decode.py       |  20 +-
 egs/aishell/ASR/conformer_mmi/subsampling.py  |  16 +-
 egs/aishell/ASR/conformer_mmi/train.py        |   8 +-
 egs/aishell/ASR/conformer_mmi/transformer.py  |  44 +-
 .../local/compute_fbank_aidatatang_200zh.py   |   8 +-
 .../ASR/local/compute_fbank_aishell.py        |   8 +-
 egs/aishell/ASR/local/prepare_char.py         |   8 +-
 egs/aishell/ASR/local/prepare_lang.py         |   4 +-
 egs/aishell/ASR/local/test_prepare_lang.py    |   4 +-
 .../pruned_transducer_stateless2/decode.py    |  36 +-
 .../pruned_transducer_stateless2/export.py    |  23 +-
 .../pretrained.py                             |  22 +-
 .../ASR/pruned_transducer_stateless2/train.py |  43 +-
 .../pruned_transducer_stateless3/decode.py    |  39 +-
 .../pruned_transducer_stateless3/export.py    |  26 +-
 .../ASR/pruned_transducer_stateless3/model.py |   8 +-
 .../pretrained.py                             |  22 +-
 .../ASR/pruned_transducer_stateless3/train.py |  58 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  28 +-
 egs/aishell/ASR/tdnn_lstm_ctc/decode.py       |  20 +-
 egs/aishell/ASR/tdnn_lstm_ctc/model.py        |   5 +-
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   |  15 +-
 egs/aishell/ASR/tdnn_lstm_ctc/train.py        |   7 +-
 .../ASR/transducer_stateless/beam_search.py   |  22 +-
 .../ASR/transducer_stateless/conformer.py     |  67 +-
 .../ASR/transducer_stateless/decode.py        |  26 +-
 .../ASR/transducer_stateless/decoder.py       |   4 +-
 .../ASR/transducer_stateless/export.py        |   7 +-
 egs/aishell/ASR/transducer_stateless/model.py |   4 +-
 .../ASR/transducer_stateless/pretrained.py    |  14 +-
 egs/aishell/ASR/transducer_stateless/train.py |  15 +-
 .../ASR/transducer_stateless/transformer.py   |   4 +-
 .../asr_datamodule.py                         |  17 +-
 .../transducer_stateless_modified-2/decode.py |  27 +-
 .../transducer_stateless_modified-2/export.py |   7 +-
 .../pretrained.py                             |  22 +-
 .../transducer_stateless_modified-2/train.py  |  22 +-
 .../transducer_stateless_modified/decode.py   |  27 +-
 .../transducer_stateless_modified/export.py   |   7 +-
 .../pretrained.py                             |  22 +-
 .../transducer_stateless_modified/train.py    |  15 +-
 egs/aishell2/ASR/local/__init__.py            |   0
 .../ASR/local/compute_fbank_aishell2.py       |   8 +-
 .../pruned_transducer_stateless5/__init__.py  |   0
 .../asr_datamodule.py                         |  24 +-
 .../pruned_transducer_stateless5/decode.py    |  39 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  18 +-
 .../ASR/pruned_transducer_stateless5/train.py |  46 +-
 .../ASR/local/compute_fbank_aishell4.py       |   8 +-
 egs/aishell4/ASR/local/prepare_char.py        |   8 +-
 egs/aishell4/ASR/local/prepare_lang.py        |   4 +-
 egs/aishell4/ASR/local/test_prepare_lang.py   |   4 +-
 egs/aishell4/ASR/local/text2token.py          |  15 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless5/decode.py    |  35 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  23 +-
 .../ASR/pruned_transducer_stateless5/train.py |  38 +-
 .../ASR/local/compute_fbank_alimeeting.py     |   8 +-
 egs/alimeeting/ASR/local/prepare_char.py      |   8 +-
 egs/alimeeting/ASR/local/prepare_lang.py      |   4 +-
 egs/alimeeting/ASR/local/test_prepare_lang.py |   4 +-
 egs/alimeeting/ASR/local/text2segments.py     |   2 +-
 egs/alimeeting/ASR/local/text2token.py        |  15 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless2/decode.py    |  35 +-
 .../pruned_transducer_stateless2/export.py    |   7 +-
 .../pretrained.py                             |  19 +-
 .../ASR/pruned_transducer_stateless2/train.py |  29 +-
 egs/csj/ASR/.gitignore                        |   2 +-
 egs/csj/ASR/local/compute_fbank_csj.py        |  38 +-
 egs/csj/ASR/local/compute_fbank_musan.py      |  17 +-
 egs/csj/ASR/local/conf/disfluent.ini          |  55 +-
 egs/csj/ASR/local/conf/fluent.ini             |  55 +-
 egs/csj/ASR/local/conf/number.ini             |  55 +-
 egs/csj/ASR/local/conf/symbol.ini             |  55 +-
 .../ASR/local/display_manifest_statistics.py  |   4 +-
 egs/csj/ASR/local/prepare_lang_char.py        |  14 +-
 egs/csj/ASR/local/validate_manifest.py        |   4 +-
 .../ASR/conformer_ctc/asr_datamodule.py       |  27 +-
 egs/gigaspeech/ASR/conformer_ctc/conformer.py |  63 +-
 egs/gigaspeech/ASR/conformer_ctc/decode.py    |  16 +-
 .../ASR/conformer_ctc/label_smoothing.py      |   7 +-
 .../ASR/conformer_ctc/subsampling.py          |  16 +-
 egs/gigaspeech/ASR/conformer_ctc/train.py     |  12 +-
 .../ASR/conformer_ctc/transformer.py          |  49 +-
 .../compute_fbank_gigaspeech_dev_test.py      |   4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   4 +-
 .../ASR/local/preprocess_gigaspeech.py        |  10 +-
 .../asr_datamodule.py                         |  27 +-
 .../pruned_transducer_stateless2/decode.py    |  28 +-
 .../pruned_transducer_stateless2/export.py    |  16 +-
 .../ASR/pruned_transducer_stateless2/train.py |  27 +-
 egs/librispeech/ASR/conformer_ctc/ali.py      |  12 +-
 .../ASR/conformer_ctc/conformer.py            |  63 +-
 egs/librispeech/ASR/conformer_ctc/decode.py   |  16 +-
 egs/librispeech/ASR/conformer_ctc/export.py   |   4 +-
 .../ASR/conformer_ctc/label_smoothing.py      |   7 +-
 .../ASR/conformer_ctc/pretrained.py           |  11 +-
 .../ASR/conformer_ctc/subsampling.py          |  16 +-
 egs/librispeech/ASR/conformer_ctc/train.py    |  20 +-
 .../ASR/conformer_ctc/transformer.py          |  49 +-
 .../ASR/conformer_ctc2/attention.py           |  19 +-
 .../ASR/conformer_ctc2/conformer.py           |  62 +-
 egs/librispeech/ASR/conformer_ctc2/decode.py  |  28 +-
 egs/librispeech/ASR/conformer_ctc2/export.py  |  21 +-
 egs/librispeech/ASR/conformer_ctc2/train.py   |  34 +-
 .../ASR/conformer_ctc2/transformer.py         |  46 +-
 .../ASR/conformer_mmi/conformer.py            |  67 +-
 egs/librispeech/ASR/conformer_mmi/decode.py   |  16 +-
 .../ASR/conformer_mmi/subsampling.py          |  16 +-
 .../ASR/conformer_mmi/test_subsampling.py     |   3 +-
 .../ASR/conformer_mmi/test_transformer.py     |   9 +-
 .../ASR/conformer_mmi/train-with-attention.py |  27 +-
 egs/librispeech/ASR/conformer_mmi/train.py    |  27 +-
 .../ASR/conformer_mmi/transformer.py          |  28 +-
 .../decode.py                                 |  35 +-
 .../emformer.py                               | 119 +---
 .../export.py                                 |  19 +-
 .../stream.py                                 |   8 +-
 .../streaming_decode.py                       |  42 +-
 .../train.py                                  |  35 +-
 .../decode.py                                 |  35 +-
 .../emformer.py                               | 108 +--
 .../export.py                                 |  19 +-
 .../streaming_decode.py                       |  42 +-
 .../train.py                                  |  35 +-
 .../ASR/local/add_alignment_librispeech.py    |  12 +-
 egs/librispeech/ASR/local/compile_hlg.py      |   6 +-
 egs/librispeech/ASR/local/compile_lg.py       |   4 +-
 .../compute_fbank_gigaspeech_dev_test.py      |   4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   4 +-
 .../ASR/local/compute_fbank_librispeech.py    |   8 +-
 .../ASR/local/compute_fbank_musan.py          |   8 +-
 .../convert_transcript_words_to_tokens.py     |   8 +-
 egs/librispeech/ASR/local/download_lm.py      |   4 +-
 egs/librispeech/ASR/local/filter_cuts.py      |  10 +-
 .../ASR/local/generate_unique_lexicon.py      |   4 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |   4 +-
 .../ASR/local/prepare_lm_training_data.py     |  11 +-
 .../ASR/local/preprocess_gigaspeech.py        |   4 +-
 .../ASR/local/test_prepare_lang.py            |   4 +-
 .../ASR/local/validate_manifest.py            |   4 +-
 .../ASR/lstm_transducer_stateless/decode.py   |  39 +-
 .../ASR/lstm_transducer_stateless/export.py   |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/lstm_transducer_stateless/lstm.py     |  14 +-
 .../ASR/lstm_transducer_stateless/model.py    |   8 +-
 .../lstm_transducer_stateless/pretrained.py   |  18 +-
 .../ASR/lstm_transducer_stateless/stream.py   |   8 +-
 .../streaming_decode.py                       |  41 +-
 .../ASR/lstm_transducer_stateless/train.py    |  40 +-
 .../ASR/lstm_transducer_stateless2/decode.py  |  39 +-
 .../ASR/lstm_transducer_stateless2/export.py  |  31 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/lstm_transducer_stateless2/model.py   |   8 +-
 .../lstm_transducer_stateless2/ncnn-decode.py |  11 +-
 .../lstm_transducer_stateless2/pretrained.py  |  18 +-
 .../streaming-ncnn-decode.py                  |  23 +-
 .../streaming-onnx-decode.py                  |  31 +-
 .../ASR/lstm_transducer_stateless2/train.py   |  47 +-
 .../ASR/lstm_transducer_stateless3/decode.py  |  51 +-
 .../ASR/lstm_transducer_stateless3/export.py  |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/lstm_transducer_stateless3/lstm.py    |  14 +-
 .../lstm_transducer_stateless3/pretrained.py  |  18 +-
 .../streaming_decode.py                       |  41 +-
 .../ASR/lstm_transducer_stateless3/train.py   |  45 +-
 .../ASR/pruned2_knowledge/asr_datamodule.py   |  35 +-
 .../ASR/pruned2_knowledge/beam_search.py      |  18 +-
 .../ASR/pruned2_knowledge/conformer.py        |  82 +--
 .../ASR/pruned2_knowledge/decode.py           |  25 +-
 .../ASR/pruned2_knowledge/decoder.py          |   4 +-
 .../ASR/pruned2_knowledge/decoder2.py         |  81 ++-
 .../ASR/pruned2_knowledge/export.py           |   7 +-
 .../ASR/pruned2_knowledge/joiner.py           |   4 +-
 .../ASR/pruned2_knowledge/model.py            |   8 +-
 .../ASR/pruned2_knowledge/optim.py            |  35 +-
 .../ASR/pruned2_knowledge/sampling.py         | 181 ++---
 .../ASR/pruned2_knowledge/scaling.py          |  51 +-
 .../ASR/pruned2_knowledge/scaling_tmp.py      | 355 ++++++----
 .../ASR/pruned2_knowledge/train.py            |  29 +-
 .../pruned_stateless_emformer_rnnt2/decode.py |  35 +-
 .../emformer.py                               |   8 +-
 .../pruned_stateless_emformer_rnnt2/export.py |  19 +-
 .../pruned_stateless_emformer_rnnt2/model.py  |   4 +-
 .../pruned_stateless_emformer_rnnt2/train.py  |  23 +-
 .../beam_search.py                            |  26 +-
 .../ASR/pruned_transducer_stateless/decode.py |  36 +-
 .../decode_stream.py                          |  19 +-
 .../pruned_transducer_stateless/decoder.py    |   4 +-
 .../ASR/pruned_transducer_stateless/export.py |   7 +-
 .../ASR/pruned_transducer_stateless/model.py  |   4 +-
 .../pruned_transducer_stateless/pretrained.py |  14 +-
 .../streaming_beam_search.py                  |   8 +-
 .../streaming_decode.py                       |  31 +-
 .../ASR/pruned_transducer_stateless/train.py  |  25 +-
 .../beam_search.py                            |  51 +-
 .../pruned_transducer_stateless2/conformer.py |  94 +--
 .../pruned_transducer_stateless2/decode.py    |  36 +-
 .../pruned_transducer_stateless2/decoder.py   |   8 +-
 .../pruned_transducer_stateless2/export.py    |  16 +-
 .../pruned_transducer_stateless2/joiner.py    |   4 +-
 .../ASR/pruned_transducer_stateless2/model.py |   8 +-
 .../ASR/pruned_transducer_stateless2/optim.py |  35 +-
 .../pretrained.py                             |  14 +-
 .../pruned_transducer_stateless2/scaling.py   |  53 +-
 .../streaming_beam_search.py                  |  12 +-
 .../streaming_decode.py                       |  31 +-
 .../ASR/pruned_transducer_stateless2/train.py |  37 +-
 .../asr_datamodule.py                         |  17 +-
 .../decode-giga.py                            |  32 +-
 .../pruned_transducer_stateless3/decode.py    |  50 +-
 .../pruned_transducer_stateless3/export.py    |  24 +-
 .../gigaspeech.py                             |   8 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/pruned_transducer_stateless3/model.py |   8 +-
 .../onnx_check.py                             |  24 +-
 .../onnx_pretrained.py                        |  13 +-
 .../pretrained.py                             |  14 +-
 .../scaling_converter.py                      |   4 +-
 .../streaming_decode.py                       |  31 +-
 .../pruned_transducer_stateless3/test_onnx.py |  24 +-
 .../ASR/pruned_transducer_stateless3/train.py |  44 +-
 .../pruned_transducer_stateless4/decode.py    |  51 +-
 .../pruned_transducer_stateless4/export.py    |  19 +-
 .../streaming_decode.py                       |  34 +-
 .../ASR/pruned_transducer_stateless4/train.py |  40 +-
 .../pruned_transducer_stateless5/conformer.py | 112 +---
 .../pruned_transducer_stateless5/decode.py    |  39 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  18 +-
 .../streaming_decode.py                       |  34 +-
 .../ASR/pruned_transducer_stateless5/train.py |  45 +-
 .../pruned_transducer_stateless6/conformer.py |  64 +-
 .../pruned_transducer_stateless6/decode.py    |  35 +-
 .../pruned_transducer_stateless6/export.py    |  16 +-
 .../extract_codebook_index.py                 |   3 +-
 .../hubert_decode.py                          |  17 +-
 .../hubert_xlarge.py                          |  22 +-
 .../ASR/pruned_transducer_stateless6/model.py |  12 +-
 .../ASR/pruned_transducer_stateless6/train.py |  44 +-
 .../pruned_transducer_stateless6/vq_utils.py  |  28 +-
 .../pruned_transducer_stateless7/decode.py    |  39 +-
 .../pruned_transducer_stateless7/decoder.py   |   6 +-
 .../pruned_transducer_stateless7/export.py    |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../pruned_transducer_stateless7/joiner.py    |   4 +-
 .../ASR/pruned_transducer_stateless7/model.py |  16 +-
 .../ASR/pruned_transducer_stateless7/optim.py | 436 ++++++------
 .../pretrained.py                             |  18 +-
 .../pruned_transducer_stateless7/scaling.py   | 481 +++++++-------
 .../scaling_converter.py                      |   6 +-
 .../ASR/pruned_transducer_stateless7/train.py |  48 +-
 .../pruned_transducer_stateless7/zipformer.py | 625 +++++++++---------
 .../pruned_transducer_stateless8/decode.py    |  39 +-
 .../pruned_transducer_stateless8/export.py    |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/pruned_transducer_stateless8/model.py |   4 +-
 .../pretrained.py                             |  18 +-
 .../ASR/pruned_transducer_stateless8/train.py |  59 +-
 .../ASR/streaming_conformer_ctc/README.md     |  16 +-
 .../ASR/streaming_conformer_ctc/conformer.py  | 113 +---
 .../streaming_decode.py                       |  34 +-
 .../ASR/streaming_conformer_ctc/train.py      |  16 +-
 .../streaming_conformer_ctc/transformer.py    |  40 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  23 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |  16 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/model.py    |   5 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |  21 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py    |   8 +-
 egs/librispeech/ASR/transducer/beam_search.py |  14 +-
 egs/librispeech/ASR/transducer/decode.py      |  15 +-
 egs/librispeech/ASR/transducer/export.py      |   4 +-
 egs/librispeech/ASR/transducer/pretrained.py  |  11 +-
 egs/librispeech/ASR/transducer/rnn.py         |  24 +-
 egs/librispeech/ASR/transducer/test_rnn.py    |  16 +-
 egs/librispeech/ASR/transducer/train.py       |  12 +-
 .../ASR/transducer_lstm/beam_search.py        |  14 +-
 egs/librispeech/ASR/transducer_lstm/decode.py |  15 +-
 .../ASR/transducer_lstm/encoder.py            |   4 +-
 egs/librispeech/ASR/transducer_lstm/train.py  |  12 +-
 .../ASR/transducer_stateless/alignment.py     |   4 +-
 .../ASR/transducer_stateless/beam_search.py   |  28 +-
 .../ASR/transducer_stateless/compute_ali.py   |  11 +-
 .../ASR/transducer_stateless/conformer.py     | 104 +--
 .../ASR/transducer_stateless/decode.py        |  23 +-
 .../ASR/transducer_stateless/decoder.py       |   4 +-
 .../ASR/transducer_stateless/export.py        |   7 +-
 .../ASR/transducer_stateless/joiner.py        |   8 +-
 .../ASR/transducer_stateless/pretrained.py    |  14 +-
 .../transducer_stateless/test_compute_ali.py  |  11 +-
 .../transducer_stateless/test_conformer.py    |   4 +-
 .../ASR/transducer_stateless/train.py         |  23 +-
 .../ASR/transducer_stateless/transformer.py   |   4 +-
 .../ASR/transducer_stateless2/decode.py       |  23 +-
 .../ASR/transducer_stateless2/export.py       |   7 +-
 .../ASR/transducer_stateless2/pretrained.py   |  14 +-
 .../ASR/transducer_stateless2/train.py        |  23 +-
 .../decode.py                                 |  23 +-
 .../export.py                                 |   7 +-
 .../pretrained.py                             |  14 +-
 .../test_asr_datamodule.py                    |   4 +-
 .../train.py                                  |  22 +-
 egs/ptb/LM/local/sort_lm_training_data.py     |   4 +-
 .../LM/local/test_prepare_lm_training_data.py |   4 +-
 .../ASR/local/compute_fbank_musan.py          |   8 +-
 .../ASR/local/compute_fbank_spgispeech.py     |  14 +-
 egs/spgispeech/ASR/local/prepare_splits.py    |   8 +-
 .../asr_datamodule.py                         |  24 +-
 .../pruned_transducer_stateless2/decode.py    |  52 +-
 .../pruned_transducer_stateless2/export.py    |  13 +-
 .../ASR/pruned_transducer_stateless2/train.py |  30 +-
 .../ASR/local/compute_fbank_tal_csasr.py      |   8 +-
 egs/tal_csasr/ASR/local/prepare_char.py       |   4 +-
 egs/tal_csasr/ASR/local/prepare_lang.py       |   4 +-
 egs/tal_csasr/ASR/local/test_prepare_lang.py  |   4 +-
 egs/tal_csasr/ASR/local/text2token.py         |  15 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless5/decode.py    |  39 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  18 +-
 .../ASR/pruned_transducer_stateless5/train.py |  38 +-
 .../ASR/local/compute_fbank_tedlium.py        |   8 +-
 .../convert_transcript_words_to_bpe_ids.py    |   4 +-
 egs/tedlium3/ASR/local/prepare_lexicon.py     |  11 +-
 egs/tedlium3/ASR/local/prepare_transcripts.py |  11 +-
 .../ASR/pruned_transducer_stateless/decode.py |  19 +-
 .../ASR/pruned_transducer_stateless/export.py |   7 +-
 .../pruned_transducer_stateless/pretrained.py |  19 +-
 .../ASR/pruned_transducer_stateless/train.py  |  14 +-
 .../transducer_stateless/asr_datamodule.py    |  37 +-
 .../ASR/transducer_stateless/beam_search.py   |  30 +-
 .../ASR/transducer_stateless/decode.py        |  18 +-
 .../ASR/transducer_stateless/decoder.py       |   4 +-
 .../ASR/transducer_stateless/export.py        |   7 +-
 .../ASR/transducer_stateless/pretrained.py    |  14 +-
 .../ASR/transducer_stateless/train.py         |  11 +-
 egs/timit/ASR/RESULTS.md                      |   2 +-
 egs/timit/ASR/local/compile_hlg.py            |   4 +-
 egs/timit/ASR/local/compute_fbank_timit.py    |   8 +-
 egs/timit/ASR/local/prepare_lexicon.py        |   8 +-
 egs/timit/ASR/prepare.sh                      |   4 +-
 egs/timit/ASR/tdnn_ligru_ctc/decode.py        |  16 +-
 egs/timit/ASR/tdnn_ligru_ctc/model.py         |  12 +-
 egs/timit/ASR/tdnn_ligru_ctc/pretrained.py    |  21 +-
 egs/timit/ASR/tdnn_ligru_ctc/train.py         |   4 +-
 egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py |  30 +-
 egs/timit/ASR/tdnn_lstm_ctc/decode.py         |  16 +-
 egs/timit/ASR/tdnn_lstm_ctc/model.py          |   5 +-
 egs/timit/ASR/tdnn_lstm_ctc/pretrained.py     |  21 +-
 egs/timit/ASR/tdnn_lstm_ctc/train.py          |   4 +-
 .../compute_fbank_wenetspeech_dev_test.py     |  11 +-
 .../local/compute_fbank_wenetspeech_splits.py |   4 +-
 egs/wenetspeech/ASR/local/prepare_char.py     |   8 +-
 .../ASR/local/preprocess_wenetspeech.py       |   6 +-
 egs/wenetspeech/ASR/local/text2token.py       |  15 +-
 egs/wenetspeech/ASR/prepare.sh                |   2 +-
 .../asr_datamodule.py                         |  31 +-
 .../pruned_transducer_stateless2/decode.py    |  39 +-
 .../pruned_transducer_stateless2/export.py    |  15 +-
 .../jit_pretrained.py                         |   7 +-
 .../onnx_check.py                             |  24 +-
 .../onnx_pretrained.py                        |  13 +-
 .../pretrained.py                             |  19 +-
 .../ASR/pruned_transducer_stateless2/train.py |  29 +-
 .../pruned_transducer_stateless5/conformer.py |  94 +--
 .../pruned_transducer_stateless5/decode.py    |  41 +-
 .../decode_stream.py                          |  19 +-
 .../pruned_transducer_stateless5/export.py    |   7 +-
 .../pretrained.py                             |  19 +-
 .../streaming_beam_search.py                  |   8 +-
 .../streaming_decode.py                       |  34 +-
 .../ASR/pruned_transducer_stateless5/train.py |  46 +-
 egs/yesno/ASR/local/compile_hlg.py            |   4 +-
 egs/yesno/ASR/local/compute_fbank_yesno.py    |  12 +-
 egs/yesno/ASR/tdnn/decode.py                  |  16 +-
 egs/yesno/ASR/tdnn/pretrained.py              |  15 +-
 egs/yesno/ASR/tdnn/train.py                   |   4 +-
 egs/yesno/ASR/transducer/decode.py            |  12 +-
 egs/yesno/ASR/transducer/train.py             |   4 +-
 icefall/char_graph_compiler.py                |   8 +-
 icefall/checkpoint.py                         |  12 +-
 icefall/decode.py                             |  40 +-
 icefall/diagnostics.py                        |  74 +--
 icefall/dist.py                               |   4 +-
 icefall/env.py                                |   4 +-
 icefall/graph_compiler.py                     |   4 +-
 icefall/hooks.py                              |  19 +-
 icefall/lexicon.py                            |  16 +-
 icefall/mmi.py                                |  29 +-
 icefall/mmi_graph_compiler.py                 |   8 +-
 icefall/rnn_lm/dataset.py                     |   8 +-
 icefall/rnn_lm/export.py                      |   4 +-
 icefall/rnn_lm/model.py                       |  28 +-
 icefall/rnn_lm/train.py                       |   8 +-
 icefall/shared/make_kn_lm.py                  | 177 +++--
 icefall/utils.py                              |  62 +-
 pyproject.toml                                |   2 +-
 setup.py                                      |   3 +-
 test/test_checkpoint.py                       |   6 +-
 test/test_decode.py                           |   1 +
 test/test_graph_compiler.py                   |   4 +-
 test/test_utils.py                            |   4 +-
 437 files changed, 3861 insertions(+), 7334 deletions(-)
 mode change 100755 => 100644 egs/aishell2/ASR/local/__init__.py
 mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
 mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py

diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 90459bc1c..45d261ccc 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -45,17 +45,18 @@ jobs:
 
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
-          # See https://github.com/psf/black/issues/2964
-          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          # Click issue fixed in https://github.com/psf/black/pull/2966
 
       - name: Run flake8
         shell: bash
         working-directory: ${{github.workspace}}
         run: |
           # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --show-source --statistics
-          flake8 .
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
+            --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
 
       - name: Run black
         shell: bash
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 446ba0fe7..5cb213327 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,26 +1,38 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 21.6b0
+    rev: 22.3.0
     hooks:
       - id: black
-        args: [--line-length=80]
-        additional_dependencies: ['click==8.0.1']
+        args: ["--line-length=88"]
+        additional_dependencies: ['click==8.1.0']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8
-    rev: 3.9.2
+    rev: 5.0.4
     hooks:
       - id: flake8
-        args: [--max-line-length=80]
+        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
+
+      # What are we ignoring here?
+      # E203: whitespace before ':'
+      # E266: too many leading '#' for block comment
+      # E501: line too long
+      # F401: module imported but unused
+      # E402: module level import not at top of file
+      # F403: 'from module import *' used; unable to detect undefined names
+      # F841: local variable is assigned to but never used
+      # W503: line break before binary operator
+      # In addition, the default ignore list is:
+      # E121,E123,E126,E226,E24,E704,W503,W504
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.9.2
+    rev: 5.10.1
     hooks:
       - id: isort
-        args: [--profile=black, --line-length=80]
+        args: ["--profile=black"]
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
+    rev: v4.2.0
     hooks:
       - id: check-executables-have-shebangs
       - id: end-of-file-fixer
diff --git a/docker/README.md b/docker/README.md
index 6f2314e96..c14b9bf75 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -2,7 +2,7 @@
 
 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
 
-If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. 
+If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
 
 Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVIDIA driver supports at least CUDA 11.0.
 
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
 
 ```bash
 $ nvidia-smi
-Tue Sep 20 00:26:13 2022       
+Tue Sep 20 00:26:13 2022
 +-----------------------------------------------------------------------------+
 | NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
 |-------------------------------+----------------------+----------------------+
@@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
 | 41%   30C    P8    11W / 280W |      6MiB / 24220MiB |      0%      Default |
 |                               |                      |                  N/A |
 +-------------------------------+----------------------+----------------------+
-                                                                               
+
 +-----------------------------------------------------------------------------+
 | Processes:                                                                  |
 |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
@@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
 ```
 
 ## Building images locally
-If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. 
-For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. 
+If your environment requires a proxy to access the Internet, remember to add that information to the Dockerfile directly.
+In most cases, you can uncomment these lines in the Dockerfile and fill in your proxy details.
 
 ```dockerfile
 ENV http_proxy=http://aaa.bb.cc.net:8080 \
     https_proxy=http://aaa.bb.cc.net:8080
 ```
 
-Then, proceed with these commands. 
+Then, proceed with these commands.
 
 ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
 
@@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
 ```
 
 ### Tips:
-1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 
+1. Since your data and models most likely won't be inside the Docker container, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
 
 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
 
-Overall, your docker run command should look like this. 
+Overall, your docker run command should look like this.
 
 ```bash
 docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
@@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
 
 ### Linking to icefall in your host machine
 
-If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. 
+If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
 
-Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. 
+Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
 Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
 
 Use these commands once you are inside the container.
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
 docker exec -it icefall /bin/bash
 ```
 
-## Restarting a killed container that has been run before. 
+## Restarting a killed container that has been run before.
 ```bash
 docker start -ai icefall
 ```
@@ -111,4 +111,4 @@ docker start -ai icefall
 ## Sample usage of the CPU based images:
 ```bash
 docker run -it icefall /bin/bash
-``` 
+```
diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
index 3637d2f11..ff9e40604 100644
--- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
@@ -1,7 +1,7 @@
 FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080 
+#	https_proxy=http://aaa.bbb.cc.net:8080
 
 # install normal source
 RUN apt-get update && \
@@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-	
-# flac 
+
+# flac
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \ 
+    cd /opt && \
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd - 
+    cd -
 
 RUN conda install -y -c pytorch torchaudio=0.12 && \
     pip install graphviz
-	
+
 
 #install k2 from source
 RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
 	pip install -r requirements.txt
 
-RUN pip install kaldifeat 
+RUN pip install kaldifeat
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
index 17a8215f9..5c7423fa5 100644
--- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
@@ -1,12 +1,12 @@
 FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080 
+#	https_proxy=http://aaa.bbb.cc.net:8080
 
 RUN rm /etc/apt/sources.list.d/cuda.list && \
 	rm /etc/apt/sources.list.d/nvidia-ml.list && \
 	apt-key del 7fa2af80
-	
+
 # install normal source
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
@@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
 	curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
 	echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
 	echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
-	rm -rf /var/lib/apt/lists/* && \ 
+	rm -rf /var/lib/apt/lists/* && \
 	mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
     mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
     mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-	
-# flac 
+
+# flac
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \ 
+    cd /opt && \
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd - 
+    cd -
 
 RUN conda install -y -c pytorch torchaudio=0.7.1 && \
     pip install graphviz
@@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
     cd -
 
 # install  lhotse
-RUN pip install git+https://github.com/lhotse-speech/lhotse 
+RUN pip install git+https://github.com/lhotse-speech/lhotse
 
 RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
@@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
-
diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
index 534b2e534..3019ff03d 100644
--- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
+++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
@@ -1 +1 @@
-k2: >= v1.9k2>= v1.9
\ No newline at end of file
+k2: >= v1.9k2>= v1.9
diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg
index 4254dc58a..df677ad09 100644
--- a/docs/source/installation/images/python-gt-v3.6-blue.svg
+++ b/docs/source/installation/images/python-gt-v3.6-blue.svg
@@ -1 +1 @@
-python: >= 3.6python>= 3.6
\ No newline at end of file
+python: >= 3.6python>= 3.6
diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
index d3ece9a17..d7007d742 100644
--- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg
+++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
@@ -1 +1 @@
-torch: >= 1.6.0torch>= 1.6.0
\ No newline at end of file
+torch: >= 1.6.0torch>= 1.6.0
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst
index d072d6e9c..b77d59bca 100644
--- a/docs/source/recipes/aishell/index.rst
+++ b/docs/source/recipes/aishell/index.rst
@@ -19,4 +19,3 @@ It can be downloaded from ``_
    tdnn_lstm_ctc
    conformer_ctc
    stateless_transducer
-
diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst
index 17f40cdb7..5ee147be7 100644
--- a/docs/source/recipes/timit/index.rst
+++ b/docs/source/recipes/timit/index.rst
@@ -6,4 +6,3 @@ TIMIT
 
    tdnn_ligru_ctc
    tdnn_lstm_ctc
-
diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
index 186420ee7..3d7aefe02 100644
--- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
@@ -148,10 +148,10 @@ Some commonly used options are:
 
         $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
 
-    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, 
-    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, 
-    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, 
-    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, 
+    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
+    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
+    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
+    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
     ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_ligru_ctc/pretrained.py 
+  ./tdnn_ligru_ctc/pretrained.py
     --method 1best
-    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt 
-    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt 
-    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt 
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
+    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
+    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
+    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -337,7 +337,7 @@ The output is:
   2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 20:41:39,829 INFO [pretrained.py:267] 
+  2021-11-08 20:41:39,829 INFO [pretrained.py:267]
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
@@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
     --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.1 \
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -378,7 +378,7 @@ The decoding output is:
   2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:37:56,348 INFO [pretrained.py:267] 
+  2021-11-08 20:37:56,348 INFO [pretrained.py:267]
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
index 6f760a9ce..ee67a6edc 100644
--- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
@@ -148,8 +148,8 @@ Some commonly used options are:
 
         $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
 
-    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, 
-    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, 
+    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
+    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
     ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_lstm_ctc/pretrained.py 
+  ./tdnn_lstm_ctc/pretrained.py
     --method 1best
-    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt 
-    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt 
-    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt 
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
+    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
+    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
+    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -335,7 +335,7 @@ The output is:
   2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
   2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 21:02:54,387 INFO [pretrained.py:267] 
+  2021-11-08 21:02:54,387 INFO [pretrained.py:267]
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
 
@@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` method, you can use
     --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.08 \
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -376,7 +376,7 @@ The decoding output is:
   2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:05:27,878 INFO [pretrained.py:267] 
+  2021-11-08 20:05:27,878 INFO [pretrained.py:267]
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
index fb2751c0f..387c14acf 100755
--- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_char.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/aidatatang_200zh/ASR/local/text2token.py
+++ b/egs/aidatatang_200zh/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 039951354..4749e1b7f 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   if [ ! -f $lang_char_dir/words.txt ]; then
     ./local/prepare_words.py \
       --input-file $lang_char_dir/words_no_ids.txt \
-      --output-file $lang_char_dir/words.txt 
+      --output-file $lang_char_dir/words.txt
   fi
 
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py
   fi
 fi
-
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 6a5b57e24..167d5e15e 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index f0407f429..b1c7c2839 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -69,11 +69,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -192,8 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -425,8 +413,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index 00b54c39f..de37ec7e4 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -103,8 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -173,9 +172,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index eb5e6b0d4..548b7263c 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -194,8 +193,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index d46838b68..322fa6b00 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -81,9 +81,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index cb7205e51..ab1cbbae4 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 751b7d5b5..74a7b5933 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -401,9 +401,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -431,9 +429,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -441,9 +437,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -562,9 +556,7 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py
index 42b8c29e7..1df3cfdc2 100644
--- a/egs/aishell/ASR/conformer_ctc/export.py
+++ b/egs/aishell/ASR/conformer_ctc/export.py
@@ -157,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index 27776bc24..e0dcb8ad4 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -211,8 +211,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -274,9 +273,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -371,9 +368,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py
index 542fb0364..8e0f73d05 100644
--- a/egs/aishell/ASR/conformer_ctc/subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/subsampling.py
@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
index e3361d0c9..81fa234dd 100755
--- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
@@ -16,9 +16,8 @@
 # limitations under the License.
 
 
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
 import torch
+from subsampling import Conv2dSubsampling, VggSubsampling
 
 
 def test_conv2d_subsampling():
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index a228cc1fe..c2cbe6e3b 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -382,9 +382,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -630,9 +626,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py
index f93914aaa..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_ctc/transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -183,9 +181,7 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -266,23 +262,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -343,23 +333,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -836,9 +818,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -859,9 +839,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py
index cb7205e51..ab1cbbae4 100644
--- a/egs/aishell/ASR/conformer_mmi/conformer.py
+++ b/egs/aishell/ASR/conformer_mmi/conformer.py
@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 4db367e36..20a855e7f 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -413,9 +413,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -443,9 +441,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -453,9 +449,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -550,9 +544,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -581,9 +573,7 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py
index 720ed6c22..398837a46 100644
--- a/egs/aishell/ASR/conformer_mmi/subsampling.py
+++ b/egs/aishell/ASR/conformer_mmi/subsampling.py
@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 685831d09..09cd6e60c 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -511,9 +511,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -625,9 +623,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py
index f93914aaa..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_mmi/transformer.py
+++ b/egs/aishell/ASR/conformer_mmi/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -183,9 +181,7 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -266,23 +262,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -343,23 +333,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -836,9 +818,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -859,9 +839,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
index 42700a972..037971927 100755
--- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index deab6c809..115ca1031 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -111,9 +109,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aishell/ASR/local/prepare_char.py
+++ b/egs/aishell/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aishell/ASR/local/prepare_lang.py
+++ b/egs/aishell/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aishell/ASR/local/test_prepare_lang.py
+++ b/egs/aishell/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index a12934d55..199acf6c3 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -188,8 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -263,10 +256,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -387,9 +377,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -415,9 +403,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -428,8 +414,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -473,9 +458,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -504,8 +487,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index feababdd2..4d41e425c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ from pathlib import Path
 import torch
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import str2bool
 
@@ -120,8 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -157,8 +152,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
@@ -191,9 +185,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = (
-            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
-        )
+        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -201,17 +193,14 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir
-            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index 3c38e5db7..8aa0fbdd7 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,8 +196,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -256,13 +254,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -310,9 +304,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -329,9 +321,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index 97d892754..f81ab2568 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -49,7 +49,6 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
@@ -75,9 +74,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -203,8 +200,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -227,8 +223,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -251,8 +246,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -561,11 +555,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -593,23 +583,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -725,9 +708,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -891,7 +872,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index d159e420b..f6c919e9d 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -202,8 +202,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -263,9 +262,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -277,10 +274,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -401,9 +395,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -429,9 +421,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -442,8 +432,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -488,9 +477,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -518,9 +505,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -551,9 +538,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 566902a85..5e701c121 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -132,8 +132,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -166,9 +165,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -195,9 +194,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -252,9 +251,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = (
-            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
-        )
+        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -262,17 +259,14 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir
-            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
index e150e8230..a4dda0d6d 100644
--- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -84,9 +84,7 @@ class Transducer(nn.Module):
         self.decoder_datatang = decoder_datatang
         self.joiner_datatang = joiner_datatang
 
-        self.simple_am_proj = ScaledLinear(
-            encoder_dim, vocab_size, initial_speed=0.5
-        )
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
         if decoder_datatang is not None:
@@ -179,9 +177,7 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = encoder_out_lens
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 04a0a882a..40926173c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,8 +196,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,13 +255,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -311,9 +305,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -330,9 +322,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index feaef5cf6..680986ee9 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -96,9 +96,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -224,8 +222,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -248,8 +245,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -272,8 +268,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -635,11 +630,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -670,23 +661,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -824,9 +808,7 @@ def train_one_epoch(
                 )
             # summary stats
             if datatang_train_dl is not None:
-                tot_loss = (
-                    tot_loss * (1 - 1 / params.reset_interval)
-                ) + loss_info
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
             if aishell:
                 aishell_tot_loss = (
@@ -847,9 +829,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -892,9 +872,7 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             if datatang_train_dl is not None:
                 datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
-                tot_loss_str = (
-                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                )
+                tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
             else:
                 tot_loss_str = ""
                 datatang_str = ""
@@ -1067,7 +1045,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1076,9 +1054,7 @@ def run(rank, world_size, args):
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -1093,9 +1069,7 @@ def run(rank, world_size, args):
     if params.datatang_prob > 0:
         datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
         train_datatang_cuts = datatang.train_cuts()
-        train_datatang_cuts = filter_short_and_long_utterances(
-            train_datatang_cuts
-        )
+        train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
         train_datatang_cuts = train_datatang_cuts.repeat(times=None)
         datatang_train_dl = asr_datamodule.train_dataloaders(
             train_datatang_cuts,
@@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index d24ba6bb7..fc28e8dbc 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -183,17 +183,13 @@ class AishellAsrDataModule:
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -215,9 +211,7 @@ class AishellAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -260,9 +254,7 @@ class AishellAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -308,9 +300,7 @@ class AishellAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -366,13 +356,9 @@ class AishellAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 66b734fc4..824ca2a92 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -265,9 +265,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -289,9 +287,7 @@ def save_results(
         # We compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
             test_set_wers[key] = wer
@@ -335,9 +331,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -362,9 +356,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
 
     model.to(device)
     model.eval()
@@ -392,9 +384,7 @@ def main():
             lexicon=lexicon,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 9bd810809..fe197a9f9 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -53,9 +53,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -113,8 +111,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -173,9 +170,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is [N, C, T]
 
     with torch.no_grad():
@@ -219,9 +214,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index 7619b0551..e574cf89b 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
 
 
 def get_parser():
diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index 9ed9b2ad1..de0a8d0f5 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -47,9 +47,9 @@ def greedy_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -81,9 +81,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -157,9 +157,7 @@ class HypothesisList(object):
 
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -246,9 +244,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index 64114253d..78424aea2 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernel_size should be an odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 780b0c4bb..fbc54f68b 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -99,8 +99,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -227,9 +226,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -248,9 +245,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append([lexicon.token_table[i] for i in hyp])
 
     if params.decoding_method == "greedy_search":
@@ -319,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -346,9 +339,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -359,8 +350,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -430,9 +420,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index c2c6552a9..70e9e6c96 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -86,9 +86,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 4c6519b96..eea9b6883 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -243,9 +242,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 994305fc1..591bbe44f 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -103,9 +103,7 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index db89c4d67..b03a2643a 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -117,8 +117,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -212,8 +211,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -273,9 +271,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -319,9 +315,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index d54157709..4ea902507 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,8 +126,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -389,9 +388,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -504,9 +501,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -625,9 +620,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/aishell/ASR/transducer_stateless/transformer.py
+++ b/egs/aishell/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index 838e53658..5d49d7338 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -29,10 +29,7 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import (
-    OnTheFlyFeatures,
-    PrecomputedFeatures,
-)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
@@ -162,9 +159,7 @@ class AsrDataModule:
         if cuts_musan is not None:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -173,9 +168,7 @@ class AsrDataModule:
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -252,9 +245,7 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index ea3f94fd8..cb206af6d 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -170,8 +170,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -227,9 +226,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -241,10 +238,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -365,9 +359,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -393,9 +385,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -406,8 +396,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -448,9 +437,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 3bd2ceb11..3c56d4a01 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -241,9 +240,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index a95a4bc52..d8c0c5fcd 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,8 +194,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,13 +252,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,9 +302,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -327,9 +319,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index 225d0d709..a9a30d7f7 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,8 +149,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -168,8 +167,7 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="The probability to select a batch from the " "aidatatang_200zh dataset",
     )
 
     return parser
@@ -449,9 +447,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -605,9 +601,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
@@ -735,9 +729,7 @@ def run(rank, world_size, args):
     train_datatang_cuts = train_datatang_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -776,9 +768,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 65fcda873..ba3cb3218 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -245,10 +242,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -369,9 +363,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -397,9 +389,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -410,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -452,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index 11335a834..cbdbdbeb6 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -241,9 +240,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 262e822c2..7dfa92a3c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,8 +194,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,13 +252,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,9 +302,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -327,9 +319,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index d3ffccafa..c4bf4dd56 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,8 +142,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -414,9 +413,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -529,9 +526,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -657,9 +652,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index d8d3622bd..ec0c584ca 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -111,9 +109,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
old mode 100755
new mode 100644
index b7a21f579..0f383a244
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -216,13 +216,9 @@ class AiShell2AsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -244,9 +240,7 @@ class AiShell2AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -290,9 +284,7 @@ class AiShell2AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +340,7 @@ class AiShell2AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -406,9 +396,7 @@ class AiShell2AsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 915737f4a..7900c5883 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -269,8 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -348,9 +347,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -409,10 +406,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -538,9 +532,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -573,8 +565,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -625,9 +616,7 @@ def main():
             if "LG" in params.decoding_method:
                 params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -661,9 +650,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -690,9 +679,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -749,9 +738,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index bc7bd71cb..ea4a8d4f9 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -167,9 +166,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -196,9 +195,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -266,9 +265,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 09de1bece..94536fa6f 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -159,8 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -192,8 +191,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,15 +252,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -334,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index 838a0497f..4a228113d 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -92,9 +92,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -220,8 +218,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -244,8 +241,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -268,8 +264,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -603,11 +598,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -636,23 +627,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -771,9 +755,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -829,9 +811,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -939,7 +919,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 3f50d9e3e..400c406f0 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -120,9 +118,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aishell4/ASR/local/prepare_char.py
+++ b/egs/aishell4/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aishell4/ASR/local/prepare_lang.py
+++ b/egs/aishell4/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aishell4/ASR/local/test_prepare_lang.py
+++ b/egs/aishell4/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/aishell4/ASR/local/text2token.py
+++ b/egs/aishell4/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 7aa53ddda..d980a857f 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -222,17 +222,13 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class Aishell4AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +294,7 @@ class Aishell4AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -359,9 +351,7 @@ class Aishell4AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index 14e44c7d9..cb533df35 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -201,8 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -260,9 +259,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -277,10 +274,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -401,9 +395,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -436,8 +428,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -480,9 +471,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -510,9 +499,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -543,9 +532,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index 993341131..cc9b7b444 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -169,9 +168,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -202,9 +201,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -276,9 +275,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index 1fa893637..a234f9d65 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -205,8 +204,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -266,15 +264,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -306,10 +300,7 @@ def main():
 
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -350,9 +341,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index 0a48b9059..73ee34284 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -85,9 +85,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -213,8 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -237,8 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -261,8 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -599,11 +594,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -633,22 +624,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -827,9 +811,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -937,7 +919,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
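
For readers skimming the compute_loss hunk above: the reflowed one-liner encodes a three-stage warmup schedule for the pruned loss. A small standalone sketch of that schedule (illustrative; the function name mirrors the variable in the diff):

def pruned_loss_scale(warmup: float) -> float:
    # warmup < 1.0       -> 0.0 (pruned loss switched off while the model is
    #                            still learning the alignment)
    # 1.0 < warmup < 2.0 -> 0.1 (pruned loss eased in with a small weight)
    # otherwise          -> 1.0 (full weight once fully warmed up)
    return 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)


assert pruned_loss_scale(0.5) == 0.0
assert pruned_loss_scale(1.5) == 0.1
assert pruned_loss_scale(3.0) == 1.0

The total loss is then simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss, exactly as the collapsed line in the hunk shows.
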
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index af926aa53..96115a230 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -121,9 +119,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
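
The cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) expression collapsed in the hunk above triples the training data with speed-perturbed copies. A hedged sketch of the same idiom using lhotse (the manifest path is a placeholder, not taken from the recipe):

from lhotse import load_manifest

# Placeholder path; substitute the recipe's actual training cuts manifest.
cuts = load_manifest("data/fbank/cuts_train.jsonl.gz")

# Perturbing at 0.9x and 1.1x adds two resampled copies of every cut, so the
# augmented set is roughly three times the original total duration.
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
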
diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/alimeeting/ASR/local/prepare_char.py
+++ b/egs/alimeeting/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/alimeeting/ASR/local/prepare_lang.py
+++ b/egs/alimeeting/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/alimeeting/ASR/local/test_prepare_lang.py
+++ b/egs/alimeeting/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
index 7c1019aa8..27b904fc8 100644
--- a/egs/alimeeting/ASR/local/text2segments.py
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -30,8 +30,8 @@ with word segmenting:
 
 import argparse
 
-import paddle
 import jieba
+import paddle
 from tqdm import tqdm
 
 paddle.enable_static()
diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/alimeeting/ASR/local/text2token.py
+++ b/egs/alimeeting/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
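
The text2token.py hunk above compresses the UTF-8 stream wrapping; as a reminder of what that idiom does, here is a Python 3-only sketch (the Python 2 branch from the original script is dropped for brevity):

import codecs
import sys

# Wrap the binary stdin/stdout buffers so the script always reads and
# writes UTF-8, independent of the shell's locale settings.
reader = codecs.getreader("utf-8")(sys.stdin.buffer)
writer = codecs.getwriter("utf-8")(sys.stdout.buffer)

for line in reader:
    writer.write(line)
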
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index bf6faad7a..a9a4675a9 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -205,17 +205,13 @@ class AlimeetingAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class AlimeetingAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -282,9 +276,7 @@ class AlimeetingAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -341,9 +333,7 @@ class AlimeetingAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
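
Several hunks in this data module shorten the OnTheFlyFeatures construction. For orientation, a minimal standalone sketch of that input strategy (assumes lhotse is installed; this is illustrative, not code from the patch):

from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset, OnTheFlyFeatures

# 80-bin fbank features are computed inside the dataloader workers on the
# fly, instead of being precomputed and stored alongside the cuts.
dataset = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    return_cuts=True,
)
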
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index 6358fe970..f3b63b222 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -70,11 +70,7 @@ from beam_search import (
 from lhotse.cut import Cut
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -193,8 +189,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -425,8 +413,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -563,8 +550,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -574,8 +560,7 @@ def main():
     )
 
     test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
     ]
     cuts_test_webdataset = CutSet.from_webdataset(
         test_shards,
@@ -588,9 +573,7 @@ def main():
         return 1.0 <= c.duration
 
     cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
-    cuts_test_webdataset = cuts_test_webdataset.filter(
-        remove_short_and_long_utt
-    )
+    cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
 
     dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
     test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 8beec1b8a..538853f67 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -103,8 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -173,9 +172,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index 93b1e1f57..4da8d8e14 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -194,8 +193,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
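
The pretrained.py hunks above compress the wave-reading assertion and the feature padding call. A small standalone sketch of the same pipeline (file names and feature shapes are placeholders):

import math
from typing import List

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def read_sound_files(filenames: List[str], expected_sample_rate: int = 16000):
    """Return the first channel of each file, checking the sample rate."""
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        )
        ans.append(wave[0])  # keep only the first channel
    return ans


# Pad variable-length feature matrices into a single batch; log(1e-10) is the
# log-mel value the scripts use for (near-)silent padding frames.
features = [torch.randn(100, 80), torch.randn(80, 80)]
batch = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(batch.shape)  # torch.Size([2, 100, 80])
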
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index 81a0ede7f..c9d2f3cb9 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -81,9 +81,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
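
Side note on the 2**22 hunks in these train.py files: the reformatter only removes the spaces around the operator; the value is unchanged, and 2**22 = 4,194,304 bytes, i.e. the 4 megabytes per sub-module that the trailing comment refers to.
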
diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore
index 5d965832e..cd0e20c4c 100644
--- a/egs/csj/ASR/.gitignore
+++ b/egs/csj/ASR/.gitignore
@@ -5,4 +5,4 @@ notify_tg.py
 finetune_*
 misc.ini
 .vscode/*
-offline/*
\ No newline at end of file
+offline/*
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index 994dedbdd..c248aa668 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -25,15 +25,10 @@ from random import Random
 from typing import List, Tuple
 
 import torch
-from lhotse import (
+from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
     CutSet,
     Fbank,
     FbankConfig,
-    # fmt: off
-    # See the following for why LilcomChunkyWriter is preferred
-    # https://github.com/k2-fsa/icefall/pull/404
-    # https://github.com/lhotse-speech/lhotse/pull/527
-    # fmt: on
     LilcomChunkyWriter,
     RecordingSet,
     SupervisionSet,
@@ -81,17 +76,13 @@ def make_cutset_blueprints(
         cut_sets.append((f"eval{i}", cut_set))
 
     # Create train and valid cuts
-    logging.info(
-        "Loading, trimming, and shuffling the remaining core+noncore cuts."
-    )
+    logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
     recording_set = RecordingSet.from_file(
         manifest_dir / "csj_recordings_core.jsonl.gz"
     ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
     supervision_set = SupervisionSet.from_file(
         manifest_dir / "csj_supervisions_core.jsonl.gz"
-    ) + SupervisionSet.from_file(
-        manifest_dir / "csj_supervisions_noncore.jsonl.gz"
-    )
+    ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
 
     cut_set = CutSet.from_manifests(
         recordings=recording_set,
@@ -101,15 +92,12 @@ def make_cutset_blueprints(
     cut_set = cut_set.shuffle(Random(RNG_SEED))
 
     logging.info(
-        "Creating valid and train cuts from core and noncore,"
-        f"split at {split}."
+        "Creating valid and train cuts from core and noncore," f"split at {split}."
     )
     valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
 
     train_set = CutSet.from_cuts(islice(cut_set, split, None))
-    train_set = (
-        train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
-    )
+    train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
 
     cut_sets.extend([("valid", valid_set), ("train", train_set)])
 
@@ -122,15 +110,9 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument(
-        "--manifest-dir", type=Path, help="Path to save manifests"
-    )
-    parser.add_argument(
-        "--fbank-dir", type=Path, help="Path to save fbank features"
-    )
-    parser.add_argument(
-        "--split", type=int, default=4000, help="Split at this index"
-    )
+    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
+    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
+    parser.add_argument("--split", type=int, default=4000, help="Split at this index")
 
     return parser.parse_args()
 
@@ -141,9 +123,7 @@ def main():
     extractor = Fbank(FbankConfig(num_mel_bins=80))
     num_jobs = min(16, os.cpu_count())
 
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
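
The formatter = "..." hunks recur throughout this patch; they all configure the same logging layout. For reference, a standalone sketch (illustrative, not part of the patch):

import logging

# Timestamp, level, source file and line number, then the message.
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info("fbank computation finished")
# Prints something like:
# 2022-11-02 16:15:56,000 INFO [example.py:7] fbank computation finished
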
diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py
index 44a33c4eb..f60e62c85 100644
--- a/egs/csj/ASR/local/compute_fbank_musan.py
+++ b/egs/csj/ASR/local/compute_fbank_musan.py
@@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
 
-
 ARGPARSE_DESCRIPTION = """
 This file computes fbank features of the musan dataset.
 It looks for manifests in the directory data/manifests.
@@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
         # create chunks of Musan with duration 5 - 10 seconds
         musan_cuts = (
             CutSet.from_manifests(
-                recordings=combine(
-                    part["recordings"] for part in manifests.values()
-                )
+                recordings=combine(part["recordings"] for part in manifests.values())
             )
             .cut_into_windows(10.0)
             .filter(lambda c: c.duration > 5)
@@ -107,21 +104,15 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument(
-        "--manifest-dir", type=Path, help="Path to save manifests"
-    )
-    parser.add_argument(
-        "--fbank-dir", type=Path, help="Path to save fbank features"
-    )
+    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
+    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
 
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = get_args()
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan(args.manifest_dir, args.fbank_dir)
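
The compute_fbank_musan.py hunk above chains windowing and filtering on the combined MUSAN recordings. A hedged sketch of that chain (the manifest path is a placeholder and feature extraction is omitted):

from lhotse import CutSet, RecordingSet

# Placeholder path; substitute the real MUSAN recordings manifest.
recordings = RecordingSet.from_file("data/manifests/musan_recordings.jsonl.gz")

# Cut the long MUSAN recordings into 10 s windows and keep only chunks longer
# than 5 s, matching the "duration 5 - 10 seconds" comment in the diff.
musan_cuts = (
    CutSet.from_manifests(recordings=recordings)
    .cut_into_windows(10.0)
    .filter(lambda c: c.duration > 5)
)
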
diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini
index eb70673de..c987e72c5 100644
--- a/egs/csj/ASR/local/conf/disfluent.ini
+++ b/egs/csj/ASR/local/conf/disfluent.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj  
+; # lhotse prepare csj
 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.  
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
-; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. 
-; # If you intend to use a multicharacter string for gap_sym, remember to register the 
-; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
-gap_sym = 
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
+; # Pass an empty string to avoid adding any symbol. It was "" in kaldi.
+; # If you intend to use a multicharacter string for gap_sym, remember to register the
+; # multicharacter string as part of userdef-string in prepare_lang_char.py.
+gap_sym =
 
 [CONSTANTS]
 ; # Name of this mode
@@ -115,59 +115,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))', 
+; # Example: 'コク(笑 サイ+(D オン))',
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)' 
+; # Example: '(泣 ドンナニ)'
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ' 
+; # Example: 'シャ(咳 リン) ノ'
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0
 
 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
- = 
+ =
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
- = 
+ =
 ; # 非語彙的な母音の引き延ばし
- = 
+ =
 ; # 非語彙的な子音の引き延ばし
- = 
+ =
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> = 
+<笑> =
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> = 
+<咳> =
 ; # 言語音と独立に講演者の息が生じている場合
-<息> = 
+<息> =
 ; # 講演者の泣き声
-<泣> = 
+<泣> =
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> = 
+<フロア発話> =
 ; # 聴衆の笑い
-<フロア笑> = 
+<フロア笑> =
 ; # 聴衆の拍手
-<拍手> = 
+<拍手> =
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> = 
+<デモ> =
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> = 
+<ベル> =
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> = 
+<朗読間違い> =
 ; # 上記以外の音で特に目立った音
-<雑音> = 
+<雑音> =
 ; # 0.2秒以上のポーズ
-

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index 5d22f9eb8..f7f27f5bc 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index 2613c3409..cf9038f62 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index 8ba451dd5..f9801284b 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,4 +319,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c9de21073..c043cf853 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,9 +37,7 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to cutset manifests" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index e4d996871..ef91f6e43 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -87,9 +87,7 @@ def main(): args = get_args() logging.basicConfig( - format=( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" - ), + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s"), level=logging.INFO, ) @@ -111,8 +109,7 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name}" - f" at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name}" f" at {args.trans_mode} mode." ) for cut in train_set: try: @@ -123,8 +120,7 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in " - f"{cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in " f"{cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -143,9 +139,7 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text( - "\n".join(args.userdef_string) - ) + (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 0c4c6c1ea..7f67c64b6 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -89,9 +89,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index d78e26240..72dcd772a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -183,23 +183,18 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev " "(speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), 
preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -221,9 +216,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -256,9 +249,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -304,9 +295,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -362,9 +351,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 6fac07f93..a1cfe6e75 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model 
= d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,31 +696,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -771,9 +750,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +762,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +796,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 9c1418baa..d7035a1f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -476,9 +476,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" 
this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -493,9 +491,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -528,9 +524,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -705,9 +699,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..3b94f0c4b 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,13 +78,10 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 2965cde18..4883d04d8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -521,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -641,9 +637,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 8209ee3ec..07beeb1f0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 6410249db..1c71be0f9 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -134,9 +134,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 48d10a157..31abe7fff 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,19 +98,13 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c87686e1e..7f114fba6 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -195,8 +195,7 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev " "(speeds up training)", ) def train_dataloaders( @@ -216,13 +215,9 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -244,9 +239,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( 
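Note: the preprocess hunk above folds the 0.9x and 1.1x speed-perturbed copies into a single CutSet. Assuming lhotse's usual semantics, where a speed factor f rescales each cut's duration by 1/f, the combined set carries roughly three times the original audio:

    # Rough duration accounting for speed perturbation (illustrative numbers).
    orig_hours = 100.0
    factors = [0.9, 1.1]
    total_hours = orig_hours + sum(orig_hours / f for f in factors)
    print(round(total_hours, 1))  # ~302.0 hours derived from 100 hours of source audio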
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +282,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +338,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -405,9 +394,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a3471..c0b17750e 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,11 +77,7 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -188,8 +184,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +253,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +268,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -398,9 +388,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -434,8 +422,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,8 +498,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..3d1e7bc18 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
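Note: `find_checkpoints` / `average_checkpoints` above select the last few checkpoints and average their weights. A simplified sketch of the averaging step, assuming each checkpoint stores its weights under a "model" key as these recipes do:

    import torch

    def average_state_dicts(filenames):
        # Element-wise mean of the parameters from several checkpoints.
        avg = None
        for name in filenames:
            state = torch.load(name, map_location="cpu")["model"]
            if avg is None:
                avg = {k: v.detach().clone().float() for k, v in state.items()}
            else:
                for k in avg:
                    avg[k] += state[k].float()
        return {k: v / len(filenames) for k, v in avg.items()}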
1 means bigram; " "2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae25561..f51584120 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -178,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -202,8 +199,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -553,23 +549,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
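Note: the warm-up logic collapsed onto one line above follows this schedule; a direct restatement with the same thresholds as the hunk:

    def pruned_loss_scale(warmup: float) -> float:
        # Keep the pruned loss out of the objective until the simple loss has
        # roughly learned the alignment, then phase it in.
        return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

    for w in (0.5, 1.5, 3.0):
        print(w, pruned_loss_scale(w))  # 0.0, 0.1, 1.0
    # loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss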
info["loss"] = loss.detach().cpu().item() @@ -732,9 +721,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e..42e14abac 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -231,9 +231,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +256,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +285,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93..a1cfe6e75 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class 
RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,31 +696,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
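Note: in the self-attention branch above, q, k and v come from one fused projection that is split with `.chunk(3, dim=-1)`. A minimal shape check of that pattern:

    import torch
    import torch.nn as nn

    embed_dim, T, N = 256, 10, 4
    in_proj = nn.Linear(embed_dim, 3 * embed_dim)  # fused q/k/v projection
    query = torch.randn(T, N, embed_dim)

    q, k, v = in_proj(query).chunk(3, dim=-1)
    print(q.shape, k.shape, v.shape)  # each torch.Size([10, 4, 256])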
) @@ -771,9 +750,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +762,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +796,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3f3b1acda..7e0bf5b7b 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -551,9 +551,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" 
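Note: `matrix_ac` and `matrix_bd` above are the two terms of the Transformer-XL relative attention score, (q + u)·k for content and (q + v)·p for position, with `rel_shift` aligning the latter. A shape-only sketch with illustrative tensors, not the module's internals:

    import torch

    batch, head, d_k, time1, time2 = 2, 4, 64, 10, 10
    q_with_bias_u = torch.randn(batch, head, time1, d_k)  # q + content bias u
    q_with_bias_v = torch.randn(batch, head, time1, d_k)  # q + position bias v
    k = torch.randn(batch, head, d_k, time2)
    p = torch.randn(batch, head, d_k, 2 * time1 - 1)      # relative position keys

    matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
    matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
    # rel_shift() maps matrix_bd to (batch, head, time1, time2); the scores are
    # then (matrix_ac + matrix_bd) * scaling.
    print(matrix_ac.shape, matrix_bd.shape)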
this_batch = [] hyp_words = [] for ref_text in texts: @@ -568,9 +566,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -602,9 +598,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -809,9 +803,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 28c28df01..fbcbd7b29 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -157,9 +157,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index 1f2f3b137..cb0d6e04d 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,13 +82,10 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index a2c0a5486..8200af866 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -237,8 +237,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
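Note: the LabelSmoothingLoss hunk above builds the smoothed target distribution in a single expression. A tiny worked example of what `true_dist` contains:

    import torch
    import torch.nn.functional as F

    num_classes, label_smoothing = 5, 0.1
    target = torch.tensor([2])
    true_dist = (
        F.one_hot(target, num_classes).float() * (1 - label_smoothing)
        + label_smoothing / num_classes
    )
    print(true_dist)  # tensor([[0.0200, 0.0200, 0.9200, 0.0200, 0.0200]])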
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -300,9 +299,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -427,9 +424,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 6419f6816..1449bc310 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,9 +393,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -422,9 +420,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -453,9 +449,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -568,9 +562,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -733,9 +725,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
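Note: `utt_pad_proportion` above sums, over the batch, the fraction of padded frames in each utterance. A small numeric check:

    import torch

    padded_len = 1000                             # feature.size(1)
    num_frames = torch.tensor([1000, 800, 600])   # real frames per utterance
    pad_prop = ((padded_len - num_frames) / padded_len).sum().item()
    print(pad_prop)  # ~0.6 (0.0 + 0.2 + 0.4)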
""" - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 1375d7245..356d3f21b 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ -from scaling import ScaledLinear - class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -76,9 +75,7 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -94,9 +91,7 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear( - embed_dim, 3 * embed_dim, bias=bias - ) + self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -107,12 +102,8 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index b906d2650..09f1eb000 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,9 +29,8 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from torch import Tensor, nn from subsampling import Conv2dSubsampling - +from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask @@ -182,9 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -373,9 +368,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = 
self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -650,9 +643,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -721,31 +714,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -784,9 +768,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -794,13 +776,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -834,13 +812,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -863,9 +837,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' 
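Note: the ConvolutionModule above requires an odd kernel size so that 'SAME' padding works out; the usual choice is `padding = (kernel_size - 1) // 2`, which keeps the time dimension unchanged:

    import torch
    import torch.nn as nn

    kernel_size, channels, T = 31, 8, 100
    conv = nn.Conv1d(
        channels, channels, kernel_size, padding=(kernel_size - 1) // 2
    )
    print(conv(torch.randn(1, channels, T)).shape)  # torch.Size([1, 8, 100])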
padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 97f2f2d39..0b271a51c 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -658,9 +658,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +673,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +705,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,9 +846,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +875,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -985,9 +979,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc..7892b03c6 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ import logging from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -177,9 +176,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -206,9 +205,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 
+272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 18fa3e69f..ceea0c22c 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -498,11 +496,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -531,9 +525,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -560,9 +552,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -580,9 +570,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -776,9 +764,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -791,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation 
loss") diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index 3ef7edc23..d3443dc94 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 +21,17 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -959,9 +943,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as 
tuples sequence_idx = supervision_segments[idx, 0].item() @@ -982,9 +964,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a2..53e48eb13 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,31 +690,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not 
correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -765,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489..e3c7b685f 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -478,9 +478,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +510,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +649,7 @@ def main(): if 
params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -687,9 +681,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222e..ad9415987 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
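Note: the `--export` branch above saves the averaged model with `torch.save({"model": model.state_dict()}, ...)`; loading it back is the usual state-dict round trip. A self-contained sketch with a stand-in module in place of the Conformer:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in for the acoustic model
    torch.save({"model": model.state_dict()}, "pretrained.pt")

    restored = nn.Linear(4, 2)
    checkpoint = torch.load("pretrained.pt", map_location="cpu")
    restored.load_state_dict(checkpoint["model"])
    restored.eval()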
diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d77..d0bb017dd 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e680607..25d18076d 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73..f8c94cff9 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -370,10 +361,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -762,19 +750,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2..5cfb2bfc7 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ 
b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -377,10 +368,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -770,19 +758,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65c..2542d9abe 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a19..6854c82d8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +529,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +558,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d5568..1aaa3b9cb 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, 
self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. 
key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> 
torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea..334682ad6 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc1..c211b215e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe8658..621eeb952 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +819,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +853,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +882,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b..3d8d4a18a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +964,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5..d3c001942 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +529,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +558,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc7..c3739566f 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, 
self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -561,16 +543,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - 
attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -967,9 +935,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1247,9 +1211,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." 
) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e0241..998fb6e81 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d..618d8bb63 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +819,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +853,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +882,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d78..542f524a9 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +964,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index fe6a26c51..cc34a72d8 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." - ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not have alignment." 
- ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index c628dfd53..df6c609bb 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,9 +159,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5f..19bf3bff4 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,9 +132,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c5..97750f3ea 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e5..ce0ef24e7 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -144,9 +144,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0..9f8503814 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,9 +112,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, 
@@ -128,9 +126,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5..4a4093ae4 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,9 +83,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -101,9 +99,7 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8b..a8d5117c9 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -51,16 +51,12 @@ def get_args(): "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa7..3518db524 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index dff98a954..b3f0956c3 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) removed += 1 return False @@ -125,8 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. 
" f"{ratio:.3f}% data is removed." ) return ans @@ -155,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743d..3459c2f5a 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a7442..e121aefa9 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,9 +150,7 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f1..32ae8c580 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039..8aa5e461d 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): 
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a..f620b91ea 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -85,9 +85,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 27414d717..79b21fab1 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -272,8 +272,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -366,9 +365,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -427,10 +424,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -561,9 +555,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -596,8 +588,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -648,9 +639,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -682,9 +671,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No 
checkpoints found for" @@ -711,9 +700,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -772,9 +761,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index 13dac6009..45fa6d662 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py index 594c33e4f..51f4a2e8a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -124,8 +124,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c478..bbab16af7 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -672,9 +672,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -771,16 +769,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4a..e7bad7ed8 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -151,9 +149,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 2a6e2adc6..9263b41b2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c82..d8f7fd960 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -70,14 +70,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index d6376bdc0..4cc2aabb2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +805,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +839,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +868,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py 
b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index d30fc260a..b9a68753e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -222,8 +220,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -246,8 +243,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -594,11 +590,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -638,9 +630,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -653,14 +643,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -671,9 +656,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -856,9 +839,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -989,8 +970,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e..41602d207 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -295,8 +295,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -474,9 +473,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +532,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -700,9 +694,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -735,8 +727,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -789,9 +780,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -826,9 +815,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -860,9 +849,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -961,9 +950,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 190673638..2a25cb46a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ 
-225,8 +225,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -342,9 +341,7 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) + c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) warmup = 1.0 torch.onnx.export( @@ -445,13 +442,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -550,9 +543,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -585,9 +578,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -694,9 +687,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76f..40f11018f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -125,8 +125,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -315,9 +314,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d..ab2f17480 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,9 +156,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -201,8 +199,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -286,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760..2983328bf 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -169,8 +169,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -202,8 +201,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,15 +265,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +341,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e..a787a00e6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,9 +144,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -189,8 +187,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -229,9 +226,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -310,9 +305,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,9 +321,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -343,9 +334,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18..e896fd510 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -148,8 +148,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -199,9 +198,7 @@ class Model: sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +255,7 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +298,7 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +317,7 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +356,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +459,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f..056285c64 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,8 +235,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +258,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +683,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +696,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +709,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +940,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +986,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1155,9 +1134,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379..cba1ac689 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -290,8 +290,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -386,9 +385,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +438,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -522,9 +516,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -599,9 +591,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +600,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -650,8 +638,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +665,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +709,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,9 +741,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -787,9 +770,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: 
raise ValueError( f"No checkpoints found for" @@ -848,9 +831,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad6..457bd472f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a..71b37ac55 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -124,8 +124,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f4..6e51b85e4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,9 +661,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +758,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef04..e72f4ee42 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed5..dad6b905f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / 
f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +805,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +839,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +868,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be7..97ca4b94c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -232,8 +230,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -256,8 +253,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -606,11 +602,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +642,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +655,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +668,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +835,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +852,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +987,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca..3dc9164f8 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -83,8 +83,7 @@ class LibriSpeechAsrDataModule: "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -208,13 +207,9 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -236,9 +231,7 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +274,7 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +331,7 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +378,17 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b..785a8f097 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a35204..de367c234 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. 
- encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,9 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, knowledge_N, knowledge_D, knowledge_K, d_model, knowledge_base + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +308,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +362,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +377,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +652,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,31 +723,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. 
# convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -795,9 +777,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +821,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +846,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a..c3e7b01ab 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -186,8 +182,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +240,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +255,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -385,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -419,8 +407,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf1..0b9c886c7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd..0c9cee431 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ class Decoder(nn.Module): return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. 
@@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,35 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git 
a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fb..ce5f162bf 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -174,9 +173,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2a..68c663b66 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf2506..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf8220..76cd4e11e 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse 
gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00..5b595c76c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import timeit -import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional -from scaling import ScaledLinear import random +import timeit +from typing import Optional, Tuple + +import torch +from scaling import ScaledLinear +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. 
- - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -88,15 +87,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. 
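            # Sketch of the gradient derivation recomputed below (shapes only,
            # as an illustration): the forward pass computed, per position,
            #   ans = sum_k weights[..., k] * knowledge_base[indexes[..., k]]
            # so with lookup of shape (*, K, D) and ans_grad of shape (*, D):
            #   weights_grad = matmul(lookup, ans_grad.unsqueeze(-1)).squeeze(-1)  # (*, K)
            #   lookup_grad  = weights.unsqueeze(-1) * ans_grad.unsqueeze(-2)      # (*, K, D)
            # Recomputing lookup here trades a little compute for not storing
            # the (*, K, D) tensor during the forward pass.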
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. 
""" - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,18 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. 
+ y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +334,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +347,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) - -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c2583..527c735eb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def 
forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a..3f21133a0 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. 
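    Example (an illustrative sketch of the reparameterisation, not a
    prescribed recipe)::

        lin = ScaledLinear(256, 256, initial_scale=2.0)
        w = lin.get_weight()           # lin.weight * lin.weight_scale.exp()
        x = torch.randn(8, 256)
        y = lin(x)                     # same as F.linear(x, w, lin.get_bias())

    Because weight_scale is a single log-scale parameter, its gradient adjusts
    the overall magnitude of the layer separately from the direction of the
    weights.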
""" - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + 
self.groups, + ) class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. 
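    # A quick numerical check of the identity above (illustration only):
    #   x = torch.tensor([1.0], requires_grad=True)
    #   s = torch.sigmoid(x - 1.0)     # 0.5
    #   y = x * s                      # double_swish(1.0) = 0.5
    #   y.backward()
    #   x.grad                         # 0.75, and y*(1 - s) + s = 0.5*0.5 + 0.5 = 0.75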
""" + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def 
extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f6840166..c322abaf8 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -203,8 +200,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -554,23 +550,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
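# A minimal, illustrative sketch of the warm-up schedule that the expression
# just below implements (the thresholds 1.0 and 2.0 come from that expression;
# nothing else is assumed): the pruned loss is switched off early in training,
# lightly weighted during the transition, and fully weighted afterwards.
def _pruned_loss_scale(warmup: float) -> float:
    if warmup < 1.0:
        return 0.0  # pruned loss disabled before the alignment has formed
    if 1.0 < warmup < 2.0:
        return 0.1  # small weight during the transition phase
    return 1.0      # full weight once warm-up is over

assert _pruned_loss_scale(0.5) == 0.0
assert _pruned_loss_scale(1.5) == 0.1
assert _pruned_loss_scale(3.0) == 1.0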
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +722,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +822,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d30..891719f3d 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -204,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +271,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -450,8 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -494,9 +485,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,9 +517,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -557,9 +546,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094..008f40fb1 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." 
f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f5001..047a1d476 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -170,9 +169,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -199,9 +198,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 +272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdb..ed6848879 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19..69e74cc57 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,8 +209,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -233,8 +232,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -566,11 +564,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +593,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +774,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -908,8 +898,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d7..830b37cfb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b6338948..12bd7f9bb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,11 +128,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -269,8 +265,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +378,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +443,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +574,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -619,8 +607,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +665,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +703,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +741,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 386248554..e522943c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. 
- self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a0..72593173c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,9 +92,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878..be45536d8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -192,9 +191,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3f..2cca7fa27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af..6e91e0501 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,8 +221,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -292,9 +290,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f..9e09200a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f..ce8e2f348 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -269,9 +264,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +284,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +340,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +411,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -460,8 +447,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -533,8 +519,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a29..7861df874 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,8 +203,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -227,8 +226,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -562,9 +560,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -584,9 +580,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
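# An illustrative sketch (example lengths are assumed, subsampling_factor == 4)
# of the point made in comment (1) above: info["frames"] divides by the nominal
# subsampling factor, which only approximates the exact Conv2d subsampling
# ((lens - 1) // 2 - 1) // 2.
import torch

feature_lens = torch.tensor([57, 100, 163])
approx_frames = feature_lens // 4
exact_frames = (((feature_lens - 1) // 2) - 1) // 2
print(approx_frames.tolist())  # [14, 25, 40]
print(exact_frames.tolist())   # [13, 24, 40]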
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +771,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -897,8 +889,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -956,9 +947,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7..5e9428b60 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +775,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +791,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -990,9 +986,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +998,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1676,9 +1668,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +1794,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +1804,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +1827,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1995,9 +1979,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. 
It is a little confusing here because this for-loop @@ -2032,10 +2014,7 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2067,9 +2046,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33b..f94ffef59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +776,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +800,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1114,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, 
dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,31 +1185,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1264,23 +1242,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1292,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1321,9 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention 
weights over heads @@ -1498,16 +1460,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e..92138a5ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -275,8 +271,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -397,9 +392,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +458,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -608,9 +598,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -643,8 +631,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -700,9 +687,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +725,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +763,7 @@ def 
main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce..b59928103 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,15 +107,11 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589..4f1170bbc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -173,8 +168,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +216,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d08033..1954f4724 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,9 +60,7 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45..2d7f557ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = 
self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab..e5b5aeba5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -223,8 +222,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -293,9 +291,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +378,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef..c802ecf89 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * 
std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1007,11 +994,11 @@ def _test_grad_filter(): print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9bcd2f9f9..e6e0fb1c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( 
log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a03946..0eea3a782 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -271,9 +266,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -425,9 +414,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -462,8 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -536,8 +522,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf..f6702ef16 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to " "be changed.", ) parser.add_argument( @@ -234,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -258,8 +254,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -634,9 +629,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +642,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +655,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +823,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -963,8 +947,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee5..b7735be85 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -167,9 +164,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +173,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +243,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78ba..df24d9585 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -192,8 +188,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +275,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +305,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -446,9 +436,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +469,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +519,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +552,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be1..55585e08c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,11 +120,7 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -265,8 +261,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -478,9 +473,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +543,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -691,10 +681,7 @@ def decode_one_batch( return {key: hyps} else: return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps + (f"beam_size_{params.beam_size}_" f"temperature_{params.temperature}"): hyps } @@ -779,9 +766,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -814,8 +799,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -939,9 +923,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -981,8 +963,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1013,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,9 +1041,7 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba05..2e444353c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,11 +128,7 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import 
add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -235,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -509,13 +504,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -616,8 +607,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -715,9 +705,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b3..598434f54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,18 +52,14 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7db..86cb45c09 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -143,8 +143,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -330,9 +329,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e9..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef..163d737e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,9 +203,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -214,16 +212,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -232,11 +224,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -288,9 +276,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py 
b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ea5d4e674..825c6510b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -141,8 +141,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -191,11 +190,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -382,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a23..77bd6d13d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -232,8 +231,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -302,9 +300,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +387,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e6022b57..b712eeda0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e00..e85d2060a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -163,8 +159,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -272,9 +267,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +287,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +343,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +415,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -461,8 +448,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -535,8 +521,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec..598fcf344 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ -236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - 
assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a..e9ceb60de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -238,8 +234,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +257,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -672,9 +666,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +679,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +692,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +904,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +950,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1109,9 +1091,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a6..2f9a60f13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -306,8 +306,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -427,9 +426,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +482,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,9 +560,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. 
Args: @@ -643,9 +635,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +644,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -694,8 +682,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -722,9 +709,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +758,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,9 +795,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -841,9 +824,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -902,9 +885,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb..64ef89733 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -183,9 +182,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -212,9 +211,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -282,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8..d74d1c89d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,9 +537,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +566,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b..97f3e56a9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,8 +237,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -263,8 +260,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -621,11 +617,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +657,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -680,14 +670,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +683,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +862,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1013,8 +994,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b06294..b3a7d71bc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. 
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1118,9 +1103,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,31 +1174,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1253,23 +1229,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1278,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1336,13 +1300,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1441,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - 
-(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1622,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1719,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e..5c76afde6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -303,8 +303,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -477,9 +476,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +542,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -696,9 +690,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -731,8 +723,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -787,9 +778,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -828,9 +817,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - 
params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -857,9 +846,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -937,9 +926,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e..f0bfd3b4c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -280,9 +279,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..77ba0873b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483e..e750f5554 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,9 +537,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +566,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372e..a1a810d3e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -272,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -296,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +681,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +694,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +707,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +890,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1045,8 +1025,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f7..0667e7f61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. 
@@ -178,9 +175,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,31 +718,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -790,9 +772,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +780,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +816,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +841,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df04006..3734564fe 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -208,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +282,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -446,8 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -490,9 +481,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,9 +513,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -553,9 +542,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c7377..3d1e7bc18 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c..86cf34877 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ import os from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b557814..b8440f90a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from pathlib import Path from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. 
params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00d..4f9417c9f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = 
torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb..a24becb14 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,8 +201,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -227,8 +224,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -569,9 +565,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +598,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +645,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -670,14 +658,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +673,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +854,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1007,8 +986,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b..97a83b974 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,9 +68,7 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" - ) + self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -208,9 +206,7 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,9 +223,7 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") @@ -240,16 +234,13 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir / 
f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -269,8 +260,7 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +320,7 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f1..162966df8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -272,8 +272,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -393,9 +392,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +451,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -588,9 +582,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -623,8 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -679,9 +670,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,9 +707,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -747,9 +736,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -808,9 +797,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce1..5f90e6375 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -91,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ea..57af52fb1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -215,9 +214,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -244,9 +243,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -316,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index e2405d5ef..f469442ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -94,8 +94,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afe..3ddac2cf2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..0e59b0f2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. 
+import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e3..8b90c9a0d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed -import torch -from scaling import ActivationBalancer +import contextlib +import logging import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from scaling import ActivationBalancer from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,11 +37,10 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager def batched_params(self, param_group): """ @@ -73,7 +72,9 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -82,7 +83,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] + batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -94,76 +95,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. 
state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. 
much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): defaults = dict( lr=lr, @@ -183,7 +185,6 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -206,7 +207,9 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -225,13 +228,9 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. 
Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +246,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,36 +254,30 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, pairs: List[Tuple[Tensor, dict]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -314,57 +307,65 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. 
- sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) return ans - - def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -391,17 +392,18 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. 
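# [Illustrative sketch added for exposition; not part of the original patch.]
# The clipping logic in _get_clipping_scale() above boils down to: remember the
# last `clipping_update_period` whole-model gradient norms, refresh the
# threshold from their median once per period, and scale each new gradient so
# its norm does not exceed `clipping_scale * median`.  The class name below
# (MedianGradClipper) is invented for this example.
import torch


class MedianGradClipper:
    def __init__(self, clipping_scale: float = 2.0, period: int = 100):
        self.clipping_scale = clipping_scale
        self.period = period
        self.norms = torch.zeros(period)
        self.threshold = None
        self.step = 0

    def __call__(self, grad_norm: torch.Tensor) -> float:
        """Returns a factor <= 1.0 by which to scale this step's gradients."""
        self.norms[self.step % self.period] = grad_norm
        if self.step >= self.period and self.step % self.period == 0:
            median = self.norms.sort()[0][self.period // 2]
            self.threshold = self.clipping_scale * median
        self.step += 1
        if self.threshold is None:
            return 1.0  # still collecting statistics; don't clip yet
        return min(1.0, (self.threshold / (grad_norm + 1.0e-20)).item())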
@@ -411,24 +413,21 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +442,28 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +471,9 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) + delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. 
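# [Illustrative sketch added for exposition; not part of the original patch.]
# _size_update() above treats each tensor as p = u * s.exp() with rms(u) held
# roughly fixed, and does an Adam-like update on the log-scale s.  The gradient
# w.r.t. s is sum(p * grad), which is exactly what `scale_grads` accumulates.
# A minimal, non-batched, plain-SGD version of the same idea (the function name
# and constants are invented):
import torch


@torch.no_grad()  # like the optimizer, update parameters outside autograd
def size_update_sketch(
    p: torch.Tensor,
    grad: torch.Tensor,
    size_lr: float = 0.01,
    param_min_rms: float = 1.0e-05,
    param_max_rms: float = 3.0,
) -> None:
    scale_grad = (p * grad).sum()        # d(loss)/d(log-scale), since dp/ds = p
    rms = (p ** 2).mean().sqrt()
    scale_step = -size_lr * scale_grad   # plain gradient descent on s
    if rms < param_min_rms:
        scale_step = scale_step * 0.0            # too small: leave the scale alone
    elif rms > param_max_rms:
        scale_step = torch.tensor(-size_lr)      # too large: force a small shrink
    p.mul_(1.0 + scale_step)             # first-order version of p *= scale_step.exp()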
@@ -496,8 +494,7 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +506,13 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +524,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -540,12 +532,11 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +546,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +667,15 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -745,13 +734,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. 
_On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +756,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +796,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +823,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +834,31 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,79 +872,92 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim = 
Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a..758e0c036 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -210,8 +209,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba56..6f63e0629 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def 
_compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. 
""" + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). 
""" + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,9 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. 
eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,9 +527,7 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. 
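# [Illustrative usage added for exposition; not part of the original patch.]
# ActivationBalancer (defined in this file) is the identity in the forward
# pass; it only modifies gradients in the backward pass so that, per channel,
# roughly min_positive..max_positive of activations are positive and the mean
# absolute value stays within [min_abs, max_abs].  The values below mirror the
# defaults:
import torch

balancer = ActivationBalancer(
    num_channels=256,
    channel_dim=-1,
    min_positive=0.05,
    max_positive=0.95,
    min_abs=0.2,
    max_abs=100.0,
)
x = torch.randn(10, 50, 256, requires_grad=True)
y = balancer(x)  # y is numerically identical to x; only grads are adjusted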
self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -535,26 +545,35 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x, + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,17 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. 
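# [Illustrative check added for exposition; not part of the original patch.]
# A single-group version of _whitening_metric(): the metric is 1.0 when all
# eigenvalues of the centered covariance are equal ("white" data) and grows
# toward num_channels as the variance collapses onto one direction.
import torch


def whitening_metric_single_group(x: torch.Tensor) -> torch.Tensor:
    # x: (num_frames, num_channels)
    x = x - x.mean(dim=0)
    num_frames, num_channels = x.shape
    cov = torch.matmul(x.t(), x) / num_frames          # (C, C) covariance
    mean_diag = cov.diag().mean()                      # mean eigenvalue
    mean_sq = (cov ** 2).sum() / num_channels          # ~ mean squared eigenvalue
    return mean_sq / (mean_diag ** 2 + 1.0e-20)


white = torch.randn(4000, 64)
collapsed = torch.randn(4000, 1) * torch.randn(1, 64)  # rank-1, very "unwhite"
print(whitening_metric_single_group(white))      # close to 1.0
print(whitening_metric_single_group(collapsed))  # close to num_channels (64)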
- metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +666,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +676,28 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +732,7 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +752,21 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. 
self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,11 +774,14 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -768,7 +790,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +805,7 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +826,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +849,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +857,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +872,9 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +887,9 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. 
Note, we should quite rarely @@ -869,17 +897,16 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +916,39 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. 
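# [Illustrative note added for exposition; not part of the original patch.]
# One call to _find_direction_coeffs() is essentially a single power-iteration
# step: coeffs = X @ v, then the new direction is proportional to
# X.T @ coeffs = (X.T @ X) @ v.  MaxEig blends this with the previous direction
# and only does it occasionally; repeating the bare step, as in the invented
# helper below, converges to the top eigen-direction.
import torch


def power_iteration_sketch(x: torch.Tensor, num_iters: int = 10) -> torch.Tensor:
    # x: (num_frames, num_channels)
    v = torch.randn(x.shape[1])
    v = v / v.norm()
    for _ in range(num_iters):
        coeffs = (x * v).sum(dim=1, keepdim=True)   # same projection as just below
        v = (x * coeffs).sum(dim=0)
        v = v / (v.norm() + 1.0e-20)
    return v  # approximate top eigenvector of x.T @ x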
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +976,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -959,7 +985,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1000,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -990,7 +1018,6 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1029,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1056,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1072,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1099,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1131,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1144,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. 
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f..56165d1f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,11 +26,7 @@ from typing import List import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten class NonScaledNorm(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 3f27736b3..7160fc54a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,9 +84,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -269,8 +267,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -293,8 +290,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -646,11 +642,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +689,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,9 +860,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -890,11 +878,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -905,9 +889,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -915,10 +897,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1009,9 +988,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1029,7 +1006,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1054,8 +1031,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 023dec97d..b007a7308 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,32 +16,35 @@ # limitations under the License. import copy -import math -import warnings import itertools -from typing import List, Optional, Tuple, Union import logging -import torch +import math import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, - ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
- Whiten, Identity, + MaxEig, + ScaledConv1d, + Whiten, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import make_pad_mask class Zipformer(EncoderInterface): @@ -89,7 +92,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +100,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +126,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +142,9 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,27 +168,31 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." 
+ ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -206,46 +212,56 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -271,7 +287,9 @@ class Zipformer(EncoderInterface): outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -280,9 +298,11 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = 
self.downsample_output(x) @@ -312,15 +332,16 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +351,24 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,14 +376,15 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). 
self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -382,8 +399,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -398,8 +416,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -508,13 +527,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +548,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -538,15 +557,13 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. 
/ num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +596,14 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +623,12 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,7 +659,6 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -670,28 +689,27 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. 
@@ -718,42 +736,43 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -767,16 +786,14 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +812,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +830,7 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +838,7 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +848,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. 
""" - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -853,10 +866,14 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +886,9 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -888,9 +902,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -905,9 +917,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -955,7 +965,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1001,43 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. 
- self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1031,14 +1049,16 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1118,6 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1158,24 +1177,21 @@ class RelPositionMultiheadAttention(nn.Module): pos_dim = self.pos_dim # positional-encoding dim per head assert ( head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,31 +1211,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1230,7 +1237,6 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1245,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,13 +1259,16 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1275,10 +1281,9 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. 
- attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1320,20 +1325,16 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1372,11 +1373,7 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1384,48 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, 
out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - def __init__(self, - d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1430,7 +1436,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1450,19 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1481,9 +1482,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1512,10 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1529,10 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1548,10 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. 
@@ -1626,8 +1631,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1640,21 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1675,7 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1719,12 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1755,35 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. 
] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1798,6 @@ class AttentionCombine(nn.Module): return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1806,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1824,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1837,19 +1845,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 9d7335e77..3d89ae00a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -273,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -394,9 +393,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -455,10 +452,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -589,9 +583,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -624,8 +616,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -680,9 +671,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -719,9 +708,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -753,9 +742,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -816,9 +805,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 49f469e29..0a962149d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -217,9 +216,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -252,9 +251,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -326,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index e79a3a3aa..c458ee5a9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -94,8 +94,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 497b89136..39a360796 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,9 +160,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 373a48fc1..f1f0771ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -210,8 +209,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 2603bb854..ba8ed3ea8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -214,8 +212,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -285,8 +282,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -309,8 +305,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -691,11 +686,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -744,9 +735,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -952,9 +941,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -975,11 +962,7 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -992,12 +975,8 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1011,10 +990,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1054,8 +1030,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1152,9 +1127,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1172,7 +1145,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1207,9 +1180,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b..53f383c99 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. 
For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
+streaming_models/
+|-- lang_bpe
+|   |-- L.pt
+|   |-- Linv.pt
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
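
For quick sanity checking, a minimal standalone sketch (illustrative only; the local path `streaming_models/` and the file list simply follow the tree above) that verifies the expected layout after download:

```python
from pathlib import Path

# Hypothetical local path of the downloaded models, matching the tree above.
root = Path("streaming_models")

expected = [
    "lang_bpe/L.pt",
    "lang_bpe/Linv.pt",
    "lang_bpe/bpe.model",
    "lang_bpe/tokens.txt",
    "lang_bpe/words.txt",
]

missing = [rel for rel in expected if not (root / rel).exists()]
print("missing files:", missing if missing else "none")
```
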
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index ff4c91446..5fe92172e 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,36 +309,26 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(
-                    0, num_frames - embed_left_context + 1, stride
-                ):
+                for cur in range(0, num_frames - embed_left_context + 1, stride):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(
-                        cur_feature, offset
-                    )
-                    cur_embed = cur_embed.permute(
-                        1, 0, 2
-                    )  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
+                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[
-                            0, (chunk_size - 1), :
-                        ].view(1, 1, -1)
+                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
+                            1, 1, -1
+                        )
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
-                        0
-                    )
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
-                        0
-                    )
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
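
The hunk above only reflows the chunk-by-chunk decoding loop; the indexing itself is unchanged. A minimal standalone sketch with hypothetical sizes (`num_frames`, `stride`, `decoding_window`, and `embed_left_context` below are placeholders, not values from the recipe) showing which frame ranges each chunk covers:

```python
# Hypothetical sizes, chosen only to illustrate the indexing in the loop above.
num_frames = 200        # total input feature frames
embed_left_context = 7  # frames the subsampling front-end needs on the left
stride = 64             # hypothetical hop between chunk starts, in input frames
decoding_window = 71    # hypothetical number of input frames fed per chunk

chunks = []
for cur in range(0, num_frames - embed_left_context + 1, stride):
    end = min(cur + decoding_window, num_frames)
    chunks.append((cur, end))

print(chunks)  # [(0, 71), (64, 135), (128, 199), (192, 200)]
```
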
@@ -413,9 +403,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -431,22 +419,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -480,9 +462,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -554,9 +534,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
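
Both macaron hunks keep the same pattern: pre-norm, feed-forward, then a half-step residual with `ff_scale = 0.5`. A minimal standalone sketch of that block (the activation is a stand-in for illustration; the real layer defines its own feed-forward stack):

```python
import torch
import torch.nn as nn


class MacaronFFN(nn.Module):
    """Half-step feed-forward: out = residual + 0.5 * dropout(FFN(norm(x)))."""

    def __init__(self, d_model: int, dim_feedforward: int, dropout: float = 0.1):
        super().__init__()
        self.norm_ff_macaron = nn.LayerNorm(d_model)
        self.feed_forward_macaron = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.SiLU(),  # stand-in activation for illustration
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.dropout = nn.Dropout(dropout)
        self.ff_scale = 0.5

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        residual = src
        src = self.norm_ff_macaron(src)  # pre-norm (normalize_before=True) path
        return residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))


x = torch.randn(10, 2, 256)  # (time, batch, d_model)
print(MacaronFFN(256, 1024)(x).shape)  # torch.Size([10, 2, 256])
```
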
 
@@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -755,9 +731,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -783,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, offset: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -813,9 +785,7 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(
-                        1, 1, self.pe.size(-1)
-                    ),
+                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
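
All of the `RelPositionalEncoding` hunks above index into a cached table `self.pe` whose length is `2 * input_len - 1`, one row per relative offset from `T - 1` down to `-(T - 1)`. A minimal standalone sketch of that shape (toy sinusoidal rows, only so the table has content; the real module assembles `pe_positive`/`pe_negative` itself):

```python
import math

import torch

T, d_model = 5, 8  # hypothetical sequence length and model dimension

# Relative offsets j - i, ordered from T-1 down to -(T-1): 2*T - 1 values.
offsets = torch.arange(T - 1, -T, -1, dtype=torch.float32)

# Toy sinusoidal content per offset.
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model)
)
pe = torch.zeros(2 * T - 1, d_model)
pe[:, 0::2] = torch.sin(offsets.unsqueeze(1) * div_term)
pe[:, 1::2] = torch.cos(offsets.unsqueeze(1) * div_term)

print(pe.shape)         # torch.Size([9, 8]) == (2*T - 1, d_model)
print(pe[T - 1, 0::2])  # central row is offset 0, so every sin entry is 0
```
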
@@ -1050,9 +1020,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1120,31 +1090,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1185,24 +1146,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(
-            matrix_bd, offset=offset
-        )  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,13 +1189,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
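
A minimal sketch of the two-term relative-position attention score that the hunk above reformats (Transformer-XL, Section 3.3). Shapes are toy values, the rel_shift() fold-down is only described in a comment, and pos_bias_u/pos_bias_v stand in for the module's learned biases.

    import math
    import torch

    batch, head, time1, d_k = 2, 4, 5, 16
    q = torch.randn(batch, head, time1, d_k)          # queries
    k = torch.randn(batch, head, time1, d_k)          # keys (time2 == time1 here)
    p = torch.randn(batch, head, 2 * time1 - 1, d_k)  # projected relative position embeddings

    pos_bias_u = torch.randn(head, 1, d_k)  # content bias ("u" in the paper)
    pos_bias_v = torch.randn(head, 1, d_k)  # position bias ("v" in the paper)

    q_with_bias_u = q + pos_bias_u
    q_with_bias_v = q + pos_bias_v

    # matrix a + c: content scores, (batch, head, time1, time2)
    matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
    # matrix b + d: position scores, (batch, head, time1, 2*time1-1);
    # rel_shift() folds the last axis down to time2 before the two terms
    # are added and scaled by 1/sqrt(d_k), as in the code above
    matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
    scaling = 1.0 / math.sqrt(d_k)
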
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index a74c51836..3965bd5c3 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,6 +28,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -86,8 +87,7 @@ def get_parser():
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,"
-        "only used during decoding",
+        help="tailing dummy frames padded to the right, " "only used during decoding",
     )
 
     parser.add_argument(
@@ -248,13 +248,9 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(
-        memory_key_padding_mask, 0
-    )  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [
-        remove_duplicates_and_blank(token_id) for token_id in token_ids
-    ]
+    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
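
The hunk above is the CTC greedy decoding path: argmax per frame, collapse repeats, drop blanks. A small self-contained sketch of that collapse step (it mirrors what remove_duplicates_and_blank is used for, minus the padding-mask handling):

    from typing import List

    def ctc_greedy_collapse(frame_ids: List[int], blank_id: int = 0) -> List[int]:
        """Collapse consecutive repeats, then remove blank symbols."""
        out = []
        prev = None
        for t in frame_ids:
            if t != prev and t != blank_id:
                out.append(t)
            prev = t
        return out

    assert ctc_greedy_collapse([0, 3, 3, 0, 0, 5, 5, 5, 0, 3]) == [3, 5, 3]
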
@@ -337,9 +333,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -364,8 +358,7 @@ def save_results(
         -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir
-            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -374,8 +367,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir
-            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,9 +376,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -474,9 +464,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -507,9 +495,7 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index e41b7ea78..553b7d092 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,9 +405,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -436,9 +434,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )
 
     return loss, info
@@ -551,9 +547,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -668,9 +662,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index bc78e4a41..0c87fdf1b 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -286,23 +284,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -363,23 +355,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -652,9 +638,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -856,9 +840,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow unpacking tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -879,9 +861,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked positions are filled with True,
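
For reference, a hedged sketch of what a decoder padding mask of this kind computes: positions equal to ignore_id become True so the decoder can ignore them. The tensor values below are made up for illustration.

    import torch

    ys_pad = torch.tensor([[5, 7, 9, -1, -1],
                           [2, 4, -1, -1, -1]])
    ignore_id = -1
    mask = ys_pad == ignore_id  # (batch, max_len), True at padded positions
    # tensor([[False, False, False,  True,  True],
    #         [False, False,  True,  True,  True]])
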
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 355ccc99a..993a7cab5 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -86,8 +86,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -224,13 +223,9 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -252,9 +247,7 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse versions, the default of num_frame_masks is
             # different.
@@ -298,9 +291,7 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -356,9 +347,7 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
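
A sketch of the training data pipeline these hunks reformat: mix MUSAN cuts into the training cuts and extract 80-dim fbank features on the fly. The manifest path is a placeholder and the argument names are copied from the code above; exact import locations and signatures depend on the installed lhotse version.

    from lhotse import Fbank, FbankConfig, load_manifest
    from lhotse.dataset import CutMix, K2SpeechRecognitionDataset
    from lhotse.dataset.input_strategies import OnTheFlyFeatures

    cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")  # placeholder path
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)]

    train = K2SpeechRecognitionDataset(
        cut_transforms=transforms,
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        return_cuts=True,
    )
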
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7d0cd0bf3..92529e06c 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -336,9 +336,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +398,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -467,9 +463,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -498,9 +492,7 @@ def main():
             G=G,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
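
A hedged sketch of how a stack of per-layer LSTMs and BatchNorm1d modules like the ones above can be wired in forward(); the recipe's actual forward() is not shown in this hunk, so the residual connection here is only an assumption for illustration.

    import torch
    import torch.nn as nn

    lstms = nn.ModuleList(
        [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
    )
    bnorms = nn.ModuleList(
        [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
    )

    x = torch.randn(20, 3, 500)  # (T, N, C), the default layout for nn.LSTM
    for lstm, bnorm in zip(lstms, bnorms):
        y, _ = lstm(x)
        # BatchNorm1d expects (N, C, T), so permute around it
        y = bnorm(y.permute(1, 2, 0)).permute(2, 0, 1)
        x = x + y  # residual connection (assumed here)
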
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..addadbe4e 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -145,8 +139,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
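
A small sketch of the feature padding seen above: variable-length fbank features are padded to a common length with log(1e-10), i.e. a very small log-energy, so padded frames look like near-silence to the network.

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    features = [torch.randn(37, 80), torch.randn(25, 80)]  # (num_frames, num_mel_bins)
    padded = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    assert padded.shape == (2, 37, 80)
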
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 6b37d5c23..071ac792b 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,9 +355,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
-        .sum()
-        .item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
     )
 
     return loss, info
@@ -470,9 +468,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 11032f31a..b45b6a9d8 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -123,9 +121,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -157,9 +153,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
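
A minimal sketch of the decoder-output cache used in beam_search() above: the decoder output for a hypothesis depends only on its label prefix, so it is keyed by the joined token ids and reused whenever the same prefix is expanded again. The cached values here are stand-ins for the real (decoder_out, decoder_state) pair.

    from typing import Dict, List, Tuple

    cache: Dict[str, Tuple[str, str]] = {}

    def decoder_key(ys: List[int]) -> str:
        return "_".join(map(str, ys))

    ys = [0, 12, 7]  # blank followed by two emitted tokens
    key = decoder_key(ys)
    if key not in cache:
        # real code: decoder_out, decoder_state = model.decoder(decoder_input, ...)
        cache[key] = ("decoder_out", "decoder_state")
    decoder_out, decoder_state = cache[key]
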
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 5f233df87..804713a20 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -228,9 +228,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -245,9 +243,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -318,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -353,8 +347,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 5a5db30c4..6db0272f0 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -238,9 +238,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 1db2df648..b1ff7b2b1 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -189,8 +189,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -249,9 +248,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -287,9 +284,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index 2a165b0c1..fe8732301 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,12 +117,8 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -348,9 +344,7 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(
-            input_size=input_size, **factory_kwargs
-        )
+        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -385,9 +379,7 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(
-            List[Tuple[torch.Tensor, torch.Tensor]], []
-        )
+        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -456,12 +448,8 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
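
A sketch of why the cells above allocate 4 * hidden_size (LSTM) or 3 * hidden_size (GRU) rows per projection: a single matrix multiply produces all gate pre-activations, which are then split with chunk(). LayerNorm, applied to those pre-activations in the real cells, is omitted here.

    import torch

    input_size, hidden_size = 8, 16
    x = torch.randn(3, input_size)   # (batch, input_size)
    h = torch.randn(3, hidden_size)
    c = torch.randn(3, hidden_size)

    weight_ih = torch.randn(4 * hidden_size, input_size)
    weight_hh = torch.randn(4 * hidden_size, hidden_size)
    bias_ih = torch.randn(4 * hidden_size)
    bias_hh = torch.randn(4 * hidden_size)

    gates = x @ weight_ih.t() + bias_ih + h @ weight_hh.t() + bias_hh
    i, f, g, o = gates.chunk(4, dim=1)  # input, forget, cell, output gates
    c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_next = torch.sigmoid(o) * torch.tanh(c_next)
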
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 8591e2d8a..74c94cc70 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (
-        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
-    ).sum().backward()
+    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [
-        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
-    ]
+    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 1dd65eddb..674ea10a6 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,9 +396,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +655,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 3531a9633..5342c3e8c 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -124,9 +122,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -158,9 +154,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 604235e2a..9511ca6d7 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -225,9 +225,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -242,9 +240,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -315,9 +311,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -350,8 +344,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 3dc992dd2..038d80077 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,9 +48,7 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index cdb801e79..57bda63fd 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,9 +400,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -524,9 +522,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -665,9 +661,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index f143611ea..65f2c58d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,9 +193,7 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index ea985f30d..1d79eef9d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,9 +478,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -496,9 +494,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -786,9 +782,7 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -887,9 +881,7 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -959,9 +951,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
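
A short sketch of the hypothesis merging shown above: when two beam-search paths produce the same token sequence, their probabilities are summed, which in log space is torch.logaddexp, rather than keeping only the better path.

    import torch

    log_p_a = torch.tensor(-2.3)  # same tokens reached via one path
    log_p_b = torch.tensor(-3.1)  # ... and via another path
    merged = torch.logaddexp(log_p_a, log_p_b)  # log(exp(-2.3) + exp(-3.1)) ≈ -1.93
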
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 48769e9d1..c91198bb9 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -124,8 +124,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -162,9 +161,7 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
         batch_size = encoder_out.size(0)
 
@@ -204,9 +201,7 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cde52c9fc..01e8c5b21 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,10 +209,7 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -421,9 +418,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -439,22 +434,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -486,9 +475,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -514,9 +501,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -581,9 +566,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -625,9 +608,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(
-            src, states[1], right_context=right_context
-        )
+        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -779,9 +760,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -798,9 +777,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means the position of query vector and `j` means the
@@ -826,9 +803,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, left_context: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1092,9 +1067,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1163,31 +1138,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1228,14 +1194,10 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1243,9 +1205,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1290,9 +1250,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1304,13 +1262,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1418,16 +1372,12 @@ class ConvolutionModule(nn.Module):
                 # manually pad self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
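
A hedged sketch of the streaming convolution cache handled above: for a causal depthwise conv with kernel size k, lorder = k - 1 frames of left context are either zero-padded (training) or taken from the cache of the previous chunk (streaming), so chunked decoding matches full-utterance processing. The real ConvolutionModule also contains pointwise convs and a GLU, which are omitted here.

    import torch
    import torch.nn as nn

    k = 5
    lorder = k - 1
    conv = nn.Conv1d(in_channels=8, out_channels=8, kernel_size=k, groups=8)

    chunk1 = torch.randn(1, 8, 16)  # (batch, channels, time)
    chunk2 = torch.randn(1, 8, 16)

    # non-streaming: left-pad zeros and run over the whole utterance
    y_full = conv(nn.functional.pad(torch.cat([chunk1, chunk2], dim=2), (lorder, 0)))

    # streaming: carry the last lorder frames of chunk1 as the cache for chunk2
    y1 = conv(nn.functional.pad(chunk1, (lorder, 0)))
    cache = chunk1[..., -lorder:]
    y2 = conv(torch.cat([cache, chunk2], dim=2))
    assert torch.allclose(torch.cat([y1, y2], dim=2), y_full, atol=1e-5)
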
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 74bba9cad..688e214c8 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +229,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index fbc2373a9..a182d91e2 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,9 +87,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 8bd0bdea1..c617e6c4c 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -244,9 +243,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 93cccbd8c..e1625992d 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,13 +60,9 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
 
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
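
The joiner hunk above reflows the per-utterance combination of encoder and decoder frames; the broadcast it relies on is easy to check in isolation (illustration only, not code from the patch):

import torch

e = torch.randn(7, 512)  # one utterance: (T, C) encoder frames
d = torch.randn(3, 512)  # its decoder outputs: (U, C)

# unsqueeze(1) -> (T, 1, C), unsqueeze(0) -> (1, U, C); broadcasting adds every
# encoder frame to every decoder frame, producing a (T, U, C) joint tensor.
x = e.unsqueeze(1) + d.unsqueeze(0)
assert x.shape == (7, 3, 512)
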
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index b64521801..c393974e6 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index b00fc34f1..9af46846a 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,16 +140,13 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second)
-                for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(
-                zip(words, word_starting_time)
-            )
+            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -160,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index d1350c8ab..65b08d425 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,9 +29,7 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(
-        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
-    )
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index ae93f3348..c86125f44 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -422,9 +421,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -545,9 +542,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -664,13 +659,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -698,9 +689,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index ac2807241..c642b16bd 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +229,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 57c1a6094..229c514b9 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -104,8 +104,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -176,9 +175,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 292f77f03..9053bc6e0 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
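
The pretrained.py hunks above keep padding the batched log-Mel features with math.log(1e-10), i.e. near-silence in log space, rather than zeros. A minimal, self-contained illustration (the shapes are made up):

import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Two fake log-Mel feature matrices with different numbers of frames.
features = [torch.randn(10, 80), torch.randn(6, 80)]
feature_lengths = torch.tensor([f.size(0) for f in features])

# Pad to a common length; log(1e-10) represents (almost) zero energy, so the
# padded frames look like silence to the acoustic model.
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
assert features.shape == (2, 10, 80)
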
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index ea15c9040..71c9c5df7 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -410,9 +409,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -533,9 +530,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,13 +647,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -686,9 +677,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index d596e05cb..253821028 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -249,10 +246,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +369,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -410,8 +402,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -451,9 +442,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index b6b69d932..97b0eea4a 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -247,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index f297fa2b2..c698a35b0 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index ef51a7811..1e1188ca6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,9 +41,7 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 27912738c..e5b7dc390 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,8 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. "
-        "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -170,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -469,9 +467,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -635,9 +631,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -784,9 +778,7 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -825,9 +817,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index af54dbd07..bed3856e4 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,9 +135,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 877720e7b..3790045fa 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,9 +54,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 6cb8b65ae..9bea28a41 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,9 +87,7 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(
-                part["recordings"] for part in manifests.values()
-            )
+            recordings=combine(part["recordings"] for part in manifests.values())
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -108,8 +106,6 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 8116e7605..20ff6d7ab 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = (
-            min(args.stop, args.num_splits)
-            if args.stop > 0
-            else args.num_splits
-        )
+        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(
-                src_dir / f"cuts_{partition}_raw.jsonl.gz"
-            )
+            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 8c8f1c133..508d4acd8 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,9 +55,7 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts
-        + train_cuts.perturb_speed(0.9)
-        + train_cuts.perturb_speed(1.1)
+        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -73,9 +71,7 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
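
split_spgispeech_train() above keeps the usual 3x speed-perturbation recipe while reflowing it. A minimal sketch of the same idea (assumes lhotse is installed and `cuts` is an existing CutSet):

from lhotse import CutSet

def add_speed_perturbation(cuts: CutSet) -> CutSet:
    # Same recipe as the hunk above: keep the original cuts and add 0.9x and
    # 1.1x speed-perturbed copies, tripling the amount of training data.
    return cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
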
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index f165f6e60..d94a92503 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -176,17 +176,13 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -208,9 +204,7 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -227,9 +221,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
             )
         else:
@@ -282,9 +274,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -328,9 +318,7 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
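
The datamodule hunks above reflow, but do not change, how the MUSAN mixing transform is built. A hedged sketch of that step (the manifest path and CutMix arguments are taken from the diff; the wrapper function itself is assumed):

from pathlib import Path
from typing import List

from lhotse import load_manifest
from lhotse.dataset import CutMix

def build_musan_transform(manifest_dir: Path, enable_musan: bool) -> List:
    # Mix MUSAN noise into half of the training cuts at 10-20 dB SNR,
    # as in the reflowed code above.
    transforms = []
    if enable_musan:
        cuts_musan = load_manifest(manifest_dir / "cuts_musan.jsonl.gz")
        transforms.append(
            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
        )
    return transforms
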
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index c39bd0530..098da3ff0 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -187,8 +183,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,9 +241,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -263,10 +256,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -389,9 +379,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -424,9 +412,7 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -438,32 +424,23 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {
-        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
-    }
-    test_set_cers = {
-        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
-    }
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(
-                    key, test_set_wers[key], test_set_cers[key]
-                ),
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(
-            key, test_set_wers[key], test_set_cers[key], note
-        )
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
         note = ""
     logging.info(s)
 
@@ -496,9 +473,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -530,8 +505,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 77faa3c0e..e79cb300d 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import str2bool
 
 
@@ -119,8 +115,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -196,9 +191,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index dda29b3e5..213635894 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -155,8 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be "
-        "changed.",
+        help="The initial learning rate.  This value should not need to be " "changed.",
     )
 
     parser.add_argument(
@@ -179,8 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,8 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -554,23 +549,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -733,9 +721,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 4582609ac..602e50d29 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -112,9 +110,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 2c5b8b8b3..1262baf63 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="<space>", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 49bfb148b..2240c1c1d 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -222,17 +222,13 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +294,7 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -360,9 +352,7 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b624913f5..82e1a9437 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -208,8 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -268,9 +267,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -303,10 +300,7 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +369,7 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode(
-                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
-                )
+                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -506,9 +498,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results, zh_results, en_results
 
 
@@ -541,8 +531,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -585,9 +574,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -619,9 +606,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -648,9 +635,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 8f900208a..d0875c5f5 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -139,8 +139,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -176,9 +175,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -205,9 +204,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -277,9 +276,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index dbe213b24..da4e3bc2f 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -263,15 +261,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -367,9 +361,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index ca35eba45..97d434157 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,9 +86,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -214,8 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -238,8 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -262,8 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -600,11 +595,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -634,22 +625,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -828,9 +812,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -944,7 +926,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 327962a79..733ebf235 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,9 +83,7 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -104,9 +102,7 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 49544ccb3..9dbcc9d9e 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,9 +25,7 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--texts", type=List[str], help="The input transcripts list."
-    )
+    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index 35dd332e8..b9160b6d4 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -88,9 +87,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 1039ac5bb..7ea4e89a4 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         texts.append(s.text)
 
@@ -83,9 +82,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 2b294e601..8ca875c24 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index a1c3bcea3..71a9e2d71 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -106,8 +106,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -179,9 +178,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8480ac029..e8a453c80 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -204,8 +203,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -271,9 +269,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,10 +294,7 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -353,9 +346,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 8d5cdf683..59d80a0d8 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -157,8 +156,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -556,9 +554,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -678,9 +674,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 94784c4c4..c647392f0 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -18,7 +18,6 @@
 
 import argparse
 import logging
-
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -171,9 +170,7 @@ class TedLiumAsrDataModule:
         )
 
     def train_dataloaders(
-        self,
-        cuts_train: CutSet,
-        sampler_state_dict: Optional[Dict[str, Any]] = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
         """
         Args:
@@ -186,9 +183,7 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
 
             input_transforms.append(
                 SpecAugment(
@@ -208,13 +203,9 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -247,9 +238,7 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -306,9 +295,7 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -339,9 +326,7 @@ class TedLiumAsrDataModule:
         logging.debug("About to create test dataset")
         if self.args.on_the_fly_feats:
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -375,13 +360,9 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 77caf6460..1f99edaf3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,9 +148,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -166,9 +164,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -344,9 +340,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -383,9 +379,7 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -454,9 +448,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index d3e9e55e7..e5ab2c107 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -130,8 +130,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -250,9 +249,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,9 +272,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -348,9 +343,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -383,8 +376,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f0c6f32b6..f9a3814c6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,9 +90,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c32b1d002..c2ec7a590 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -247,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index c0e3bb844..070b070a7 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -127,8 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -223,8 +222,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -285,9 +283,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,9 +331,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 09cbf4a00..4fc13b1da 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -525,9 +524,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +644,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index b78c16b88..d8ceb82b6 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
\ No newline at end of file
+```
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 58cab4cf2..32c248d7e 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,9 +146,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index f25786a0c..ecdf10ba9 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,9 +85,7 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -101,9 +99,7 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 04023a9ab..0cf0f0deb 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = (
-        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
-    )
+    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -97,9 +95,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index ae1b96a68..d11cd3a05 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it 
+#	     on 39 phones. About how to get these LM files, you can know it
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#	
+#
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 4f2aa2340..4beeed18c 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -336,9 +336,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +398,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -462,9 +458,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -485,9 +479,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 4d2199ace..9a594a969 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
-
 from torch import Tensor
-from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,9 +261,7 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(
-                    self.num_layers, self.batch_size * 2, self.hidden_size
-                )
+                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -445,9 +443,7 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(
-                        self.N_drop_masks, self.hidden_size, device=w.device
-                    )
+                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 7da285944..4ef134412 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -145,8 +139,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 452c2a7cb..48b7feda0 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1554e987f..51ca4cc6e 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -154,9 +154,7 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.feature_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
@@ -178,9 +176,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -212,9 +210,7 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -263,9 +259,7 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -299,20 +293,14 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
-            )
+            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 5e7300cf2..502a48def 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -335,9 +335,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -399,9 +397,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,9 +457,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -483,9 +477,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index 51edb97e2..e211ad80d 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,10 +74,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
-                for _ in range(4)
-            ]
+            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 5f478da1c..3f143912e 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -145,8 +139,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index 849256b98..be1ecffaa 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index 8a9f6ed30..bd73e520e 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,12 +20,7 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import (
-    CutSet,
-    KaldifeatFbank,
-    KaldifeatFbankConfig,
-    LilcomHdf5Writer,
-)
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index a882b6113..1b257fb70 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -152,9 +152,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index 8bc073c75..d8622842f 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 817969c47..93ce750f8 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,11 +115,7 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 1c463cf1c..d1d237a68 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 755fbb2d7..da7d7e061 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq 
+      echo "This script is intended to be used with jq but you have not installed jq
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..9c07263a2 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -212,17 +212,13 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -244,9 +240,7 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -289,9 +283,7 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +340,7 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -414,8 +404,7 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir
-            / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -427,13 +416,9 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index f0c9bebec..cd9ed57b9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,11 +114,7 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -252,8 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,9 +323,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -389,10 +382,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -515,9 +505,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -550,8 +538,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -663,9 +650,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
@@ -716,8 +701,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -727,8 +711,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -739,9 +722,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 933642a0f..df2fc5df5 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -205,8 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -468,13 +467,9 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_encoder_proj.onnx"
-    )
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
 
-    decoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_decoder_proj.onnx"
-    )
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -645,9 +640,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index e5cc47bfe..42ffbcfb8 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -146,8 +146,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -331,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index c396c50ef..a46ff5a07 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,9 +219,7 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {
-            encoder_proj_input_name: encoder_out.numpy()
-        }
+        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -230,16 +228,10 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), (
-            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {
-            decoder_proj_input_name: decoder_out.numpy()
-        }
+        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -248,11 +240,7 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), (
-            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
 
 
 @torch.no_grad()
@@ -304,9 +292,7 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index 3770fbbb4..ca1e408fa 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -150,8 +150,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -200,11 +199,7 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {
-            joiner_encoder_proj.get_inputs()[
-                0
-            ].name: packed_encoder_out.data.numpy()
-        },
+        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -389,9 +384,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 9a549efd9..aaf7ac874 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -158,8 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -190,8 +189,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -253,9 +251,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +276,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d3cc7c9c9..7aba0711d 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,9 +115,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -219,8 +217,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -243,8 +240,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -590,22 +586,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -762,9 +751,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -864,7 +851,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index dd27c17f0..9bb55d07a 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,10 +210,7 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -433,9 +430,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -453,9 +448,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -520,9 +513,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -784,9 +773,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1073,9 +1060,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1144,31 +1131,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1208,23 +1186,15 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1265,21 +1235,17 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(
-                    1
-                ).unsqueeze(2)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
             else:
                 # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(
-                    0
-                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1291,13 +1257,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1430,16 +1392,12 @@ class ConvolutionModule(nn.Module):
                 # manually pad self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 344e31283..166497c31 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -244,8 +244,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -342,9 +341,7 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -360,10 +357,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -484,9 +478,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -519,8 +511,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -589,9 +580,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -618,9 +609,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -720,8 +711,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -731,8 +721,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -743,9 +732,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index 386248554..e522943c0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,9 +75,7 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -91,13 +89,11 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
-                k2.RnntDecodingStream(decoding_graph)
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
+                decoding_graph
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     @property
     def done(self) -> bool:
@@ -126,13 +122,10 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(
-            self.num_frames - self.num_processed_frames, chunk_length
-        )
+        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_length
+            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index d0a7fd69f..ff2c4db38 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -131,8 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -201,9 +200,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1b064c874..7e4829a60 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -157,8 +157,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -190,8 +189,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -253,9 +251,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +276,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 651aff6c9..810d94135 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,14 +173,10 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
-                num_active_paths
-            )
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ff96c6487..6909f40be 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -201,8 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -311,9 +310,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -333,9 +330,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -389,9 +384,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_state(params.left_context, device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -461,9 +454,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     return {key: decode_results}
 
@@ -499,8 +490,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -565,9 +555,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -594,9 +584,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 2052e9da7..5f614e77c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,9 +98,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -260,8 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -284,8 +281,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -308,8 +304,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -665,11 +660,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -701,23 +692,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -841,9 +825,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -901,9 +883,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1016,7 +996,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1184,9 +1164,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f83be05cf..7234ca929 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,9 +128,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 9a4e8a36f..75d95df68 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,9 +54,7 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(
-        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
-    )
+    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -71,9 +69,7 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -87,9 +83,7 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 9d4ab4b61..d5efb41df 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -201,9 +201,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -274,9 +272,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -297,9 +293,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -317,9 +311,7 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..88d5eca5d 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -53,9 +53,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "sound_files",
@@ -102,8 +100,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -159,9 +156,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -201,9 +196,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index f32a27f35..335493491 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index 6714180db..7f13e417a 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -116,9 +116,7 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -186,9 +184,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -303,9 +299,7 @@ def main():
         model=model,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index deb92107d..88866ae81 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 235160e14..c31db6e4c 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,9 +71,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -96,9 +94,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 5069b78e8..8aa0a8eeb 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,15 +292,11 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [
-        (int(pattern.search(c).group(1)), c) for c in checkpoints
-    ]
+    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(
-        iter_checkpoints, reverse=True, key=lambda x: x[0]
-    )
+    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -469,7 +465,5 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += (
-                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-            )
+            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
             v *= scaling_factor
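
Aside, not part of the patch: the find_checkpoints() hunk above globs checkpoint-*.pt files, parses the
iteration number out of each filename with the regex shown, sorts newest first and, for iteration >= 0,
keeps only checkpoints at or beyond that iteration. A minimal standalone sketch of the same parsing and
filtering, using made-up filenames:

    import re

    # Hypothetical filenames, standing in for the contents of an exp dir.
    names = ["checkpoint-4000.pt", "checkpoint-12000.pt", "checkpoint-8000.pt"]
    pattern = re.compile(r"checkpoint-([0-9]+).pt")

    pairs = [(int(pattern.search(n).group(1)), n) for n in names]
    pairs = sorted(pairs, reverse=True, key=lambda x: x[0])  # newest first

    iteration = 8000
    print([n for it, n in pairs if it >= iteration])
    # ['checkpoint-12000.pt', 'checkpoint-8000.pt']
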
diff --git a/icefall/decode.py b/icefall/decode.py
index 099e2d171..e4c614c4e 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,13 +334,9 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -370,9 +366,7 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(
-            path_lattice, use_double_scores=use_double_scores
-        )
+        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -442,9 +436,7 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(
-            scores_shape, self.fsa.scores.contiguous()
-        )
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
 
         tot_scores = ragged_scores.sum()
 
@@ -483,9 +475,7 @@ def one_best_decoding(
             am_scores = saved_am_scores / lm_scale
             lattice.scores = am_scores + lattice.lm_scores
 
-            best_path = k2.shortest_path(
-                lattice, use_double_scores=use_double_scores
-            )
+            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
             key = f"lm_scale_{lm_scale}"
             ans[key] = best_path
         return ans
@@ -696,9 +686,7 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -805,13 +793,9 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
-            logging.info(
-                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -823,9 +807,7 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(
-                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
         loop_count += 1
 
     # lat has token IDs as labels
@@ -912,9 +894,7 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index b075aceac..207c12bf1 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,7 +121,9 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+        self.stats = (
+            None  # we'll later assign a list to this data member.  It's a list of dict.
+        )
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -133,7 +135,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the key is "eigs" and we detect a length mismatch, we put None as the value.
 
-
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -185,17 +186,12 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if (
-                        this_dim_stats[stats_type] != []
-                        and stats_type == "eigs"
-                    ):
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulate "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(
-                            TensorAndCount(stats, count)
-                        )
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -211,7 +207,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
-
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -229,9 +224,7 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev, not variance
@@ -242,9 +235,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -261,15 +254,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "eigs" ]:
+                if stats_type in ["value", "rms", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
+                rms = (stats**2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -277,17 +270,16 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                )
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
                     f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
-
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -345,32 +337,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 7016beafb..9df1c5bd1 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,9 +29,7 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = (
-            "12354" if master_port is None else str(master_port)
-        )
+        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 8aeda6be2..373e9a9ff 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,9 +53,7 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = (
-            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-        )
+        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 570ed7d7a..e2ff03f61 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,9 +75,7 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
-            transcript_fsa
-        )
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index fbcf5e148..398a5f689 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
+
 import torch
 from torch import Tensor, nn
-import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -56,7 +57,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite" # ": {_output}"
+                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -65,28 +66,20 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(
-                            f"The sum of {_name}.grad[{i}] is not finite"
-                        )
+                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(
-                grad, _name=name
-        ):
+        def param_backward_hook(grad, _name=name):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(
-                    f"The sum of {_name}.param_grad is not finite"
-                )
+                logging.warning(f"The sum of {_name}.param_grad is not finite")
 
         parameter.register_hook(param_backward_hook)
 
 
-
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 80bd7c1ee..22e1b78bb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
-                logging.info(
-                    "Every line is expected to contain at least 2 fields"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info("Every line is expected to contain at least 2 fields")
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -119,9 +113,7 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError(
-            "It's assumed that each word has a unique pronunciation"
-        )
+        raise RuntimeError("It's assumed that each word has a unique pronunciation")
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 2c479fc2c..16ed6e032 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes])
-        .t()
-        .reshape(-1)
-        .to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -115,20 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(
-        num_graphs, dense_fsa_vec, output_beam=beam_size
-    )
-    den_lats = k2.intersect_dense(
-        den_graphs, dense_fsa_vec, output_beam=beam_size
-    )
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
+    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -168,13 +157,9 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 0d901227d..9f680f83d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,9 +137,7 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(
-            transcript_fsa_with_self_loops
-        )
+        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -155,9 +153,7 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
+            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 598e329c4..4bf982503 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,12 +155,8 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
-        y = sentence_tokens_with_eos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
+        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
+        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 094035fce..2411cb1f0 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -159,9 +159,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index a6144727a..9eef88840 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,9 +129,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -161,12 +159,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -181,12 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
 
         device = next(self.parameters()).device
 
@@ -194,9 +188,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index bb5f03fb9..3ba5bfbee 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,17 +446,13 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar(
-                    "train/tot_ppl", tot_ppl, params.batch_idx_train
-                )
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index c2edd823e..b1220d55e 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,30 +15,43 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import sys
-import os
-import re
+import argparse
 import io
 import math
-import argparse
+import os
+import re
+import sys
 from collections import Counter, defaultdict
 
-
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm", type=str, default=None, help="Path to output arpa file for language models"
+)
+parser.add_argument(
+    "-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level"
+)
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = (
+    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+)
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +65,9 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(
+            set
+        )  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -62,10 +77,15 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                      for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +105,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
+    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +123,48 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
+        infile = io.TextIOWrapper(
+            sys.stdin.buffer, encoding=default_encoding
+        )  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +174,12 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -153,9 +187,11 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-                      # but perhaps this is not the case for some other scenarios.
+        self.d = [
+            0
+        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+        # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +201,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
-                                                                # which could happen if the number of symbols is small.
-                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(
+                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
+            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +220,9 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                counts_for_hist.word_to_f[w] = (
+                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +236,17 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
                 else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +286,18 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                        for (
+                            u
+                        ) in (
+                            a_counts_for_hist.word_to_count.keys()
+                        ):  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                                1.0 - sum_z1_f_z
+                            )
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +311,9 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +376,40 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
 
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
 
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +421,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                     if bow is not None:
-                        line += '\t{0}'.format('%.7f' % math.log10(bow))
+                        line += "\t{0}".format("%.7f" % math.log10(bow))
                     print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +446,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
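
Aside, not part of the patch: the comments in cal_discounting_constants() and cal_f() above spell out the
Kneser-Ney recipe referenced at the top of make_kn_lm.py: for order N the discount is
D_N = n1_N / (n1_N + 2 * n2_N), with n1_N and n2_N the counts-of-counts for that order, and the discounted
probability of word w after history a_ is max(c(a_, w) - D_N, 0) / c(a_). A tiny standalone sketch of that
arithmetic on toy counts (the counts dict below is hypothetical):

    # Toy counts for one history a_: word -> count.
    counts = {"cat": 3, "dog": 1, "fox": 1, "owl": 2}

    n1 = sum(1 for c in counts.values() if c == 1)  # n-grams seen once  -> 2
    n2 = sum(1 for c in counts.values() if c == 2)  # n-grams seen twice -> 1
    D = max(0.1, n1 * 1.0) / (n1 + 2 * n2)          # discount, as in cal_discounting_constants()

    total = sum(counts.values())                    # c(a_) = 7
    f = {w: max(c - D, 0) / total for w, c in counts.items()}
    print(round(D, 3), {w: round(p, 3) for w, p in f.items()})
    # D = 0.5; the 4 * D / 7 of probability mass held back is what the back-off weights redistribute.
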
diff --git a/icefall/utils.py b/icefall/utils.py
index 143c79497..b4d8e9a51 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,9 +130,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -203,7 +201,7 @@ def encode_supervisions(
                 supervisions["num_frames"],
                 subsampling_factor,
                 rounding_mode="floor",
-            )
+            ),
         ),
         1,
     ).to(torch.int32)
@@ -288,13 +286,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -363,9 +357,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -586,9 +578,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -598,9 +588,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -614,9 +602,7 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -791,9 +777,7 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -803,9 +787,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -819,9 +801,7 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -891,9 +871,7 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -927,9 +905,7 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -1101,9 +1077,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1126,9 +1100,7 @@ def measure_weight_norms(
         return norms
 
 
-def measure_gradient_norms(
-    model: nn.Module, norm: str = "l1"
-) -> Dict[str, float]:
+def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1413,9 +1385,7 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
+        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index b4f8c3377..3183055d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 80
+line-length = 88
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index 6c720e121..ccd2503ff 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
-from setuptools import find_packages, setup
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
 
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 511a11c23..34e829642 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,11 +20,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    load_checkpoint,
-    save_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 97964ac67..4c2e192a7 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,6 +23,7 @@ You can run this file in one of the two ways:
 """
 
 import k2
+
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index ccfb57d49..10443cf22 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,9 +154,7 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(
-            decoding_graph, fsas, treat_epsilons_specially=False
-        )
+        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6a9ce7853..31f06bd51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,9 +50,7 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor(
-                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
-            ),
+            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
         )
     )
     assert texts == ["two", "one", "three"]

From 18e3a7a9d59ed4079a6ec53039ef60f2aeeb89f4 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 09:43:48 -0500
Subject: [PATCH 039/120] add git blame ignore file

---
 .git-blame-ignore-revs | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .git-blame-ignore-revs

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..be5901517
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
+107df3b115a58f1b68a6458c3f94a130004be34c
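
(Usage note: git only honors this file when pointed at it, e.g. with
"git config blame.ignoreRevsFile .git-blame-ignore-revs" or
"git blame --ignore-revs-file .git-blame-ignore-revs"; both are available since git 2.23.)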

From d31db010371a4128856480382876acdc0d1739ed Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 14:18:05 -0500
Subject: [PATCH 040/120] manual correction of black formatting

---
 .../pruned_transducer_stateless2/decode.py    |  2 +-
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 egs/aishell/ASR/conformer_ctc/pretrained.py   |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  4 +--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  6 ++--
 .../pruned_transducer_stateless3/decode.py    |  2 +-
 .../pruned_transducer_stateless3/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless3/train.py |  6 ++--
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   |  6 ++--
 .../ASR/transducer_stateless/decode.py        |  2 +-
 .../ASR/transducer_stateless/export.py        |  2 +-
 .../ASR/transducer_stateless/pretrained.py    |  8 ++---
 egs/aishell/ASR/transducer_stateless/train.py |  2 +-
 .../transducer_stateless_modified-2/decode.py |  2 +-
 .../transducer_stateless_modified-2/export.py |  2 +-
 .../pretrained.py                             |  8 ++---
 .../transducer_stateless_modified-2/train.py  |  4 +--
 .../transducer_stateless_modified/decode.py   |  2 +-
 .../transducer_stateless_modified/export.py   |  2 +-
 .../pretrained.py                             |  8 ++---
 .../transducer_stateless_modified/train.py    |  2 +-
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  2 +-
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 egs/csj/ASR/local/compute_fbank_csj.py        |  8 +++--
 egs/csj/ASR/local/prepare_lang_char.py        |  6 ++--
 .../ASR/conformer_ctc/asr_datamodule.py       |  2 +-
 .../asr_datamodule.py                         |  2 +-
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  4 +--
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 .../ASR/conformer_ctc/pretrained.py           |  6 ++--
 .../decode.py                                 |  2 +-
 .../emformer.py                               |  2 +-
 .../export.py                                 |  2 +-
 .../streaming_decode.py                       |  2 +-
 .../train.py                                  |  4 +--
 .../decode.py                                 |  2 +-
 .../emformer.py                               |  2 +-
 .../export.py                                 |  2 +-
 .../streaming_decode.py                       |  2 +-
 .../train.py                                  |  4 +--
 egs/librispeech/ASR/local/filter_cuts.py      |  4 +--
 .../ASR/local/prepare_lm_training_data.py     |  2 +-
 .../ASR/lstm_transducer_stateless/decode.py   |  2 +-
 .../ASR/lstm_transducer_stateless/export.py   |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../lstm_transducer_stateless/pretrained.py   |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/lstm_transducer_stateless/train.py    |  6 ++--
 .../ASR/lstm_transducer_stateless2/decode.py  |  2 +-
 .../ASR/lstm_transducer_stateless2/export.py  |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../lstm_transducer_stateless2/ncnn-decode.py |  6 ++--
 .../lstm_transducer_stateless2/pretrained.py  |  8 ++---
 .../streaming-ncnn-decode.py                  |  6 ++--
 .../streaming-onnx-decode.py                  |  6 ++--
 .../ASR/lstm_transducer_stateless2/train.py   |  8 ++---
 .../ASR/lstm_transducer_stateless3/decode.py  |  2 +-
 .../ASR/lstm_transducer_stateless3/export.py  |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../lstm_transducer_stateless3/pretrained.py  |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/lstm_transducer_stateless3/train.py   |  6 ++--
 .../ASR/pruned2_knowledge/asr_datamodule.py   |  2 +-
 .../ASR/pruned2_knowledge/decode.py           |  2 +-
 .../ASR/pruned2_knowledge/export.py           |  2 +-
 .../ASR/pruned2_knowledge/train.py            |  4 +--
 .../pruned_stateless_emformer_rnnt2/decode.py |  2 +-
 .../pruned_stateless_emformer_rnnt2/export.py |  2 +-
 .../pruned_stateless_emformer_rnnt2/train.py  |  6 ++--
 .../ASR/pruned_transducer_stateless/decode.py |  4 +--
 .../ASR/pruned_transducer_stateless/export.py |  2 +-
 .../pruned_transducer_stateless/pretrained.py |  8 ++---
 .../streaming_decode.py                       |  4 +--
 .../ASR/pruned_transducer_stateless/train.py  |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  4 +--
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  4 +--
 .../ASR/pruned_transducer_stateless2/train.py |  8 ++---
 .../decode-giga.py                            |  4 +--
 .../pruned_transducer_stateless3/decode.py    |  6 ++--
 .../pruned_transducer_stateless3/export.py    |  4 +--
 .../jit_pretrained.py                         |  6 ++--
 .../onnx_pretrained.py                        |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  4 +--
 .../ASR/pruned_transducer_stateless3/train.py | 10 +++---
 .../pruned_transducer_stateless4/decode.py    |  2 +-
 .../pruned_transducer_stateless4/export.py    |  2 +-
 .../streaming_decode.py                       |  2 +-
 .../ASR/pruned_transducer_stateless4/train.py |  6 ++--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/pruned_transducer_stateless5/train.py |  8 ++---
 .../pruned_transducer_stateless6/decode.py    |  2 +-
 .../pruned_transducer_stateless6/export.py    |  4 +--
 .../ASR/pruned_transducer_stateless6/train.py |  6 ++--
 .../pruned_transducer_stateless7/decode.py    |  2 +-
 .../pruned_transducer_stateless7/export.py    |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless7/train.py |  6 ++--
 .../pruned_transducer_stateless8/decode.py    |  2 +-
 .../pruned_transducer_stateless8/export.py    |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless8/train.py |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  2 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |  6 ++--
 egs/librispeech/ASR/transducer/pretrained.py  |  6 ++--
 .../ASR/transducer_stateless/compute_ali.py   |  2 +-
 .../ASR/transducer_stateless/decode.py        |  2 +-
 .../ASR/transducer_stateless/export.py        |  2 +-
 .../ASR/transducer_stateless/pretrained.py    |  8 ++---
 .../ASR/transducer_stateless/train.py         |  2 +-
 .../ASR/transducer_stateless2/decode.py       |  2 +-
 .../ASR/transducer_stateless2/export.py       |  2 +-
 .../ASR/transducer_stateless2/pretrained.py   |  8 ++---
 .../ASR/transducer_stateless2/train.py        |  2 +-
 .../decode.py                                 |  2 +-
 .../export.py                                 |  2 +-
 .../pretrained.py                             |  8 ++---
 .../train.py                                  |  4 +--
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../ASR/pruned_transducer_stateless2/train.py |  6 ++--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 .../ASR/pruned_transducer_stateless/decode.py |  2 +-
 .../ASR/pruned_transducer_stateless/export.py |  2 +-
 .../pruned_transducer_stateless/pretrained.py |  8 ++---
 .../ASR/pruned_transducer_stateless/train.py  |  4 +--
 .../ASR/transducer_stateless/decode.py        |  2 +-
 .../ASR/transducer_stateless/export.py        |  2 +-
 .../ASR/transducer_stateless/pretrained.py    |  8 ++---
 .../ASR/transducer_stateless/train.py         |  2 +-
 egs/timit/ASR/tdnn_ligru_ctc/pretrained.py    |  6 ++--
 egs/timit/ASR/tdnn_lstm_ctc/pretrained.py     |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  2 +-
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../onnx_pretrained.py                        |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 egs/yesno/ASR/tdnn/pretrained.py              |  6 ++--
 icefall/shared/make_kn_lm.py                  | 34 ++++++++-----------
 172 files changed, 381 insertions(+), 383 deletions(-)

diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index b1c7c2839..d0f118959 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -188,7 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index de37ec7e4..e348f7b2b 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -103,7 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index 548b7263c..75c316eaf 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,7 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index 322fa6b00..c9d9c4aa8 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -185,7 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -208,7 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index e0dcb8ad4..66d583396 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -210,9 +210,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index 199acf6c3..20a4f21c7 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -184,7 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -487,7 +487,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index 4d41e425c..2ce5cfe69 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -152,7 +152,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index 8aa0fbdd7..82c10f129 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,9 +195,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index f81ab2568..d08908238 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -200,7 +200,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -223,7 +223,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -246,7 +246,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index f6c919e9d..bac829ae1 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -202,7 +202,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 5e701c121..7f10eb36e 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -132,7 +132,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 40926173c..ead393e6e 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,9 +195,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index 680986ee9..62e67530d 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -222,7 +222,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -245,7 +245,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -268,7 +268,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index fe197a9f9..7e7213501 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -110,9 +110,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index fbc54f68b..e019d2329 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -99,7 +99,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index eea9b6883..01de5d772 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -110,7 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index b03a2643a..40f430e13 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -117,7 +117,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -210,9 +210,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 4ea902507..62ffff473 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,7 +126,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index cb206af6d..41cc1c01c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -170,7 +170,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 3c56d4a01..c1081c32b 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -109,7 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index d8c0c5fcd..5d8ca2e11 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -193,9 +193,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index a9a30d7f7..8fb7d1e49 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,7 +149,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -167,7 +167,7 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the " "aidatatang_200zh dataset",
+        help="The probability to select a batch from the aidatatang_200zh dataset",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index ba3cb3218..7c06e6e51 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -171,7 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index cbdbdbeb6..3e14ad69c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -109,7 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 7dfa92a3c..9e4459247 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -193,9 +193,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index c4bf4dd56..5f116f2bd 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,7 +142,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 7900c5883..b5da0959b 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -269,7 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index ea4a8d4f9..8a5be94d0 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 94536fa6f..bc3ae7abf 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -159,7 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -190,9 +190,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index 4a228113d..74bf68ccb 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -218,7 +218,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -241,7 +241,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -264,7 +264,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index cb533df35..37d766ec8 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -201,7 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index cc9b7b444..bf9856c60 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index a234f9d65..ee898c303 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -203,9 +203,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index 73ee34284..d7c69f226 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -211,7 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -234,7 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -257,7 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index f3b63b222..e4a90ef71 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -189,7 +189,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 538853f67..8e5cc6075 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -103,7 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index 4da8d8e14..f5a0dd8c8 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,7 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index c9d2f3cb9..e57b5c859 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -185,7 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -208,7 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index c248aa668..667ad427e 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -25,7 +25,9 @@ from random import Random
 from typing import List, Tuple
 
 import torch
-from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
+
+# fmt: off
+from lhotse import (  # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527
     CutSet,
     Fbank,
     FbankConfig,
@@ -34,6 +36,8 @@ from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter i
     SupervisionSet,
 )
 
+# fmt: on
+
 ARGPARSE_DESCRIPTION = """
 This script follows the espnet method of splitting the remaining core+noncore
 utterances into valid and train cutsets at an index which is by default 4000.
@@ -92,7 +96,7 @@ def make_cutset_blueprints(
     cut_set = cut_set.shuffle(Random(RNG_SEED))
 
     logging.info(
-        "Creating valid and train cuts from core and noncore," f"split at {split}."
+        "Creating valid and train cuts from core and noncore, split at {split}."
     )
     valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
 
diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py
index ef91f6e43..16107f543 100644
--- a/egs/csj/ASR/local/prepare_lang_char.py
+++ b/egs/csj/ASR/local/prepare_lang_char.py
@@ -87,7 +87,7 @@ def main():
     args = get_args()
 
     logging.basicConfig(
-        format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s"),
+        format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
         level=logging.INFO,
     )
 
@@ -109,7 +109,7 @@ def main():
 
     words = set()
     logging.info(
-        f"Creating vocabulary from {args.train_cut.name}" f" at {args.trans_mode} mode."
+        f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode."
     )
     for cut in train_set:
         try:
@@ -120,7 +120,7 @@ def main():
             )
         except KeyError:
             raise KeyError(
-                f"Could not find {args.trans_mode} in " f"{cut.supervisions[0].custom}"
+                f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}"
             )
         for t in text.split():
             if t in args.userdef_string:
diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index 72dcd772a..9437c935c 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -183,7 +183,7 @@ class GigaSpeechAsrDataModule:
             "--small-dev",
             type=str2bool,
             default=False,
-            help="Should we use only 1000 utterances for dev " "(speeds up training)",
+            help="Should we use only 1000 utterances for dev (speeds up training)",
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 7f114fba6..5c01d7190 100644
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -195,7 +195,7 @@ class GigaSpeechAsrDataModule:
             "--small-dev",
             type=str2bool,
             default=False,
-            help="Should we use only 1000 utterances for dev " "(speeds up training)",
+            help="Should we use only 1000 utterances for dev (speeds up training)",
         )
 
     def train_dataloaders(
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
index c0b17750e..8595c27bd 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -184,7 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -498,7 +498,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
index 3d1e7bc18..b6190e8a6 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -155,7 +155,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index f51584120..9edc42b61 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -199,7 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 8200af866..30def9c40 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -236,9 +236,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
index 6854c82d8..365e8b8a7 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
@@ -215,7 +215,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 1aaa3b9cb..91f50cf67 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -445,7 +445,7 @@ class EmformerAttention(nn.Module):
 
         if embed_dim % nhead != 0:
             raise ValueError(
-                f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})."
+                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
             )
 
         self.embed_dim = embed_dim
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
index 334682ad6..09a3e96b0 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index 621eeb952..c93125c80 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -211,7 +211,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index 3d8d4a18a..213115854 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -263,7 +263,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -286,7 +286,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
index d3c001942..78e1f4096 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
@@ -215,7 +215,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index c3739566f..3cedf99b6 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -445,7 +445,7 @@ class EmformerAttention(nn.Module):
 
         if embed_dim % nhead != 0:
             raise ValueError(
-                f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})."
+                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
             )
 
         self.embed_dim = embed_dim
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
index 998fb6e81..949214aec 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
index 618d8bb63..b2cb2c96b 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
@@ -211,7 +211,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 542f524a9..6a019fd63 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -263,7 +263,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -286,7 +286,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py
index b3f0956c3..fbcc9e24a 100644
--- a/egs/librispeech/ASR/local/filter_cuts.py
+++ b/egs/librispeech/ASR/local/filter_cuts.py
@@ -79,7 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
         total += 1
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             removed += 1
             return False
@@ -124,7 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
     ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
     ratio = removed / total * 100
     logging.info(
-        f"Removed {removed} cuts from {total} cuts. " f"{ratio:.3f}% data is removed."
+        f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed."
     )
     return ans
 
diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py
index 32ae8c580..70343fef7 100755
--- a/egs/librispeech/ASR/local/prepare_lm_training_data.py
+++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py
@@ -137,7 +137,7 @@ def main():
     for i in range(num_sentences):
         if step and i % step == 0:
             logging.info(
-                f"Processed number of lines: {i} " f"({i/num_sentences*100: .3f}%)"
+                f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)"
             )
 
         word_ids = sentences[i]
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
index 79b21fab1..3ad08f56a 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
@@ -272,7 +272,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py
index 45fa6d662..e338342cc 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
index 51f4a2e8a..c07956243 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
@@ -123,9 +123,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
index 9263b41b2..b3a34a9e3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
@@ -166,7 +166,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,9 +197,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
index 4cc2aabb2..961d8ddfb 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
@@ -199,7 +199,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index b9a68753e..a54108f6d 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -220,7 +220,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -243,7 +243,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -970,7 +970,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 41602d207..69f695fef 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -295,7 +295,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
index 2a25cb46a..5977cb36d 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
@@ -225,7 +225,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
index 40f11018f..728b09104 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
@@ -124,9 +124,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
index ab2f17480..3b471fa85 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
@@ -198,9 +198,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
index 2983328bf..f3f272b9f 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
@@ -169,7 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
index a787a00e6..baff15ea6 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
@@ -186,9 +186,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
index e896fd510..34d2e5630 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
@@ -147,9 +147,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 056285c64..8736384b4 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -161,7 +161,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -235,7 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -258,7 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -986,7 +986,7 @@ def filter_short_and_long_utterances(
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
index cba1ac689..b7953e5e3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
@@ -290,7 +290,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py
index 457bd472f..a82cad043 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
index 71b37ac55..237591a36 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
@@ -123,9 +123,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
index e72f4ee42..f49e9c518 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
@@ -166,7 +166,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,9 +197,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
index dad6b905f..109746ed5 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -199,7 +199,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
index 97ca4b94c..f56b4fd83 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -230,7 +230,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -253,7 +253,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -987,7 +987,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
index 3dc9164f8..b839a4a4c 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
@@ -83,7 +83,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py
index c3e7b01ab..40d14bb5a 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/decode.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py
@@ -182,7 +182,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py
index ce5f162bf..51020aa30 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/export.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/export.py
@@ -105,7 +105,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py
index c322abaf8..123d448bb 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/train.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/train.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -200,7 +200,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
index 891719f3d..0e3b7ff74 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
@@ -204,7 +204,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
index 047a1d476..3612a2bfd 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
index 69e74cc57..ed3fa1521 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
@@ -209,7 +209,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -232,7 +232,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -898,7 +898,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 12bd7f9bb..0444afe40 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -265,7 +265,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -703,7 +703,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index be45536d8..a19f9ab9a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -105,7 +105,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
index 6e91e0501..2ed1725b4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
@@ -168,7 +168,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -220,9 +220,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
index ce8e2f348..fbc39fb65 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -158,7 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -519,7 +519,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 7861df874..4dabbccc1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -203,7 +203,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -226,7 +226,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -889,7 +889,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 92138a5ea..5f135f219 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -271,7 +271,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -725,7 +725,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index 4f1170bbc..984caf5f2 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -168,7 +168,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index e5b5aeba5..013964720 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -168,7 +168,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -221,9 +221,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index 0eea3a782..bb08246d9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -158,7 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -522,7 +522,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index f6702ef16..86333fc97 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -208,7 +208,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to " "be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -231,7 +231,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -254,7 +254,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -947,7 +947,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py
index df24d9585..b4804ecde 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py
@@ -188,7 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -552,7 +552,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 55585e08c..03137501f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -261,7 +261,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -681,7 +681,7 @@ def decode_one_batch(
         return {key: hyps}
     else:
         return {
-            (f"beam_size_{params.beam_size}_" f"temperature_{params.temperature}"): hyps
+            (f"beam_size_{params.beam_size}_temperature_{params.temperature}"): hyps
         }
 
 
@@ -963,7 +963,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 2e444353c..239bdc12f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -231,7 +231,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -607,7 +607,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
index 86cb45c09..0669284b3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
@@ -142,9 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
index 825c6510b..550cf6aad 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
@@ -140,9 +140,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
index 77bd6d13d..7c3dfc660 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +230,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index e85d2060a..0e5111f33 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -159,7 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -521,7 +521,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index e9ceb60de..281ba4650 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -161,7 +161,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -211,7 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -234,7 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -257,7 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -950,7 +950,7 @@ def filter_short_and_long_utterances(
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 2f9a60f13..f5cbc21f7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -306,7 +306,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
index 64ef89733..401b3ef3a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index d74d1c89d..c4e3cef16 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -175,7 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 97f3e56a9..cb56c8294 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -237,7 +237,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -260,7 +260,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -994,7 +994,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 5c76afde6..8b993f638 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -303,7 +303,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
index f0bfd3b4c..a4fad1e59 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
index 77ba0873b..74a2210c3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -166,7 +166,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,9 +197,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index e750f5554..064811f1c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -175,7 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index a1a810d3e..436620744 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -246,7 +246,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -269,7 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -292,7 +292,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -1025,7 +1025,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
index 3734564fe..fd9de052a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
@@ -208,7 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py
index 3d1e7bc18..b6190e8a6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -155,7 +155,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index a24becb14..8f4d3b879 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -201,7 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -224,7 +224,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -986,7 +986,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index 162966df8..bc15948fc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -272,7 +272,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
index 57af52fb1..9a6f3ed37 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
index f469442ed..5af6dae25 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -93,9 +93,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
index 758e0c036..d05bafcfb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -208,9 +208,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 7160fc54a..b27c573ab 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -267,7 +267,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -290,7 +290,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index 3d89ae00a..e61367134 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -273,7 +273,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py
index 0a962149d..d4a228b47 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
index c458ee5a9..129497d5a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
@@ -93,9 +93,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
index f1f0771ef..486d9d74e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -208,9 +208,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index ba8ed3ea8..abe249c7b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -212,7 +212,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -282,7 +282,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -305,7 +305,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -1030,7 +1030,7 @@ def filter_short_and_long_utterances(
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index 3965bd5c3..a26d0b789 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -87,7 +87,7 @@ def get_parser():
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right," "only used during decoding",
+        help="tailing dummy frames padded to the right, only used during decoding",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 993a7cab5..95d1b273a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -86,7 +86,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index addadbe4e..fde724866 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -138,9 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index b1ff7b2b1..511610245 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -188,9 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index c91198bb9..f479389df 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -124,7 +124,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 688e214c8..643238f1b 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -171,7 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index c617e6c4c..89359f1a4 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -109,7 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index c393974e6..915a6069d 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -167,7 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c86125f44..bcb883fa5 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index c642b16bd..9a6363629 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -171,7 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 229c514b9..d33d02642 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -104,7 +104,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 9053bc6e0..0738f30c0 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -167,7 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 71c9c5df7..68e247f23 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 253821028..56ad558c6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index 97b0eea4a..3735ef452 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -110,7 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index c698a35b0..8c7726367 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -167,7 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index e5b7dc390..88987d91c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,7 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -169,7 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index 098da3ff0..219c96d60 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -183,7 +183,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -505,7 +505,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index e79cb300d..68763808a 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -115,7 +115,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index 213635894..d943180b1 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -153,7 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be " "changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -199,7 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index 82e1a9437..bf91fef7e 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -208,7 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index d0875c5f5..bc33dd160 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -139,7 +139,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index da4e3bc2f..3305f5bd3 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 97d434157..43f3231ba 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -212,7 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -235,7 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -258,7 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 8ca875c24..38f2ae83c 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index 71a9e2d71..aa22f82ec 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -106,7 +106,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index e8a453c80..8a89c3578 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -202,9 +202,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 59d80a0d8..170f37767 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -156,7 +156,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index e5ab2c107..01f08ce59 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -130,7 +130,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c2ec7a590..48dcdc736 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -110,7 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index 070b070a7..81afd6a4e 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -127,7 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -221,9 +221,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 4fc13b1da..6fed32e81 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 4ef134412..3fdf3b855 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -138,9 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 3f143912e..98c746ce5 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -138,9 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index cd9ed57b9..04602ea2e 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -248,7 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index df2fc5df5..8c4fbdd47 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -205,7 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index 42ffbcfb8..f90dd2b43 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -145,9 +145,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index ca1e408fa..9e34b4427 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -149,9 +149,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index aaf7ac874..bc499f3dd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -158,7 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -188,9 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 7aba0711d..43fa0d01b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -217,7 +217,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -240,7 +240,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 166497c31..7bd1177bd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -244,7 +244,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index ff2c4db38..35577c327 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -131,7 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     add_model_arguments(parser)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 7e4829a60..1cac20435 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -157,7 +157,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -188,9 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 6909f40be..c7863415b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -201,7 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 5f614e77c..440b65f32 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -258,7 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -281,7 +281,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -304,7 +304,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 88d5eca5d..65be77db1 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -99,9 +99,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index b1220d55e..7150297d6 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -45,13 +45,13 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-default_encoding = (
-    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-)
+# For encoding-agnostic scripts, we assume byte stream as input.
 # Need to be very careful about the use of strip() and split()
 # in this case, because there is a latin-1 whitespace character
 # (nbsp) which is part of the unicode encoding range.
 # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = "latin-1"
+
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -65,9 +65,8 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(
-            set
-        )  # using a set to count the number of unique contexts
+        # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(set)
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -151,9 +150,8 @@ class NgramCounts:
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(
-            sys.stdin.buffer, encoding=default_encoding
-        )  # byte stream as input
+        # byte stream as input
+        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
@@ -187,11 +185,10 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [
-            0
-        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
         # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
         # but perhaps this is not the case for some other scenarios.
+        self.d = [0]
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -201,11 +198,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(
-                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
-            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
+
+            # We are doing this max(0.1, xxx) to avoid zero discounting constant D due to n1=0,
             # which could happen if the number of symbols is small.
             # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -286,11 +283,8 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for (
-                            u
-                        ) in (
-                            a_counts_for_hist.word_to_count.keys()
-                        ):  # Should be careful here: what is Z1
+                        # Should be careful here: what is Z1
+                        for u in a_counts_for_hist.word_to_count.keys():
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:

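(Editor's note: a minimal sketch of the modified Kneser-Ney discounting constant computed in the hunk above, D_n = max(0.1, n1) / (n1 + 2 * n2), where n1 and n2 are the numbers of n-grams of this order seen exactly once and twice — stat[1] and stat[2] in the loop above.)

    def discount_constant(n1: int, n2: int) -> float:
        # Floor the numerator at 0.1 so n1 == 0 cannot give D = 0, which would
        # later cause a division by zero when computing back-off weights (BOW).
        assert n1 + 2 * n2 > 0
        return max(0.1, n1 * 1.0) / (n1 + 2 * n2)


    print(discount_constant(5, 3))  # 5 / 11 ≈ 0.4545
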
From 349dae35037ee468340e889fd99704336c16c2a2 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 14:18:50 -0500
Subject: [PATCH 041/120] add revision commit to git blame ignore

---
 .git-blame-ignore-revs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index be5901517..5d65b98e9 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -1,2 +1,3 @@
 # Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
 107df3b115a58f1b68a6458c3f94a130004be34c
+d31db010371a4128856480382876acdc0d1739ed

From fbe1e35b74e3ae593e3798e1a56e9d5b708a6767 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Fri, 18 Nov 2022 09:24:07 -0500
Subject: [PATCH 042/120] update code style docs

---
 docs/source/contributing/code-style.rst | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst
index 7d61a3ba1..3baaaeec2 100644
--- a/docs/source/contributing/code-style.rst
+++ b/docs/source/contributing/code-style.rst
@@ -11,9 +11,9 @@ We use the following tools to make the code style to be as consistent as possibl
 
 The following versions of the above tools are used:
 
-  - ``black == 12.6b0``
-  - ``flake8 == 3.9.2``
-  - ``isort == 5.9.2``
+  - ``black == 22.3.0``
+  - ``flake8 == 5.0.4``
+  - ``isort == 5.10.1``
 
 After running the following commands:
 
@@ -54,10 +54,17 @@ it should succeed this time:
 If you want to check the style of your code before ``git commit``, you
 can do the following:
 
+  .. code-block:: bash
+
+    $ pre-commit install
+    $ pre-commit run
+
+Or without installing the pre-commit hooks:
+
   .. code-block:: bash
 
     $ cd icefall
-    $ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2
+    $ pip install black==22.3.0 flake8==5.0.4 isort==5.10.1
     $ black --check your_changed_file.py
     $ black your_changed_file.py  # modify it in-place
     $

From 53454701cb69ec23be8a37c6ab69f1cf5104585d Mon Sep 17 00:00:00 2001
From: marcoyang 
Date: Tue, 22 Nov 2022 11:39:21 +0800
Subject: [PATCH 043/120] fix segmentation fault

---
 egs/aidatatang_200zh/ASR/prepare.sh | 3 +++
 egs/aishell/ASR/prepare.sh          | 3 +++
 egs/aishell2/ASR/prepare.sh         | 3 +++
 egs/aishell4/ASR/prepare.sh         | 3 +++
 egs/alimeeting/ASR/prepare.sh       | 3 +++
 egs/csj/ASR/prepare.sh              | 3 +++
 egs/gigaspeech/ASR/prepare.sh       | 3 +++
 egs/librispeech/ASR/prepare.sh      | 3 +++
 egs/ptb/LM/prepare.sh               | 3 +++
 egs/spgispeech/ASR/prepare.sh       | 3 +++
 egs/tal_csasr/ASR/prepare.sh        | 3 +++
 egs/tedlium3/ASR/prepare.sh         | 3 +++
 egs/timit/ASR/prepare.sh            | 3 +++
 egs/wenetspeech/ASR/prepare.sh      | 3 +++
 egs/yesno/ASR/prepare.sh            | 3 +++
 15 files changed, 45 insertions(+)

diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 4749e1b7f..46ecd5769 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index eaeecfc4a..5917668a1 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh
index 06810bfdd..3e8e840ab 100755
--- a/egs/aishell2/ASR/prepare.sh
+++ b/egs/aishell2/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=30
diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh
index c351e3964..cb2b73a3e 100755
--- a/egs/aishell4/ASR/prepare.sh
+++ b/egs/aishell4/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index 17224bb68..604cc92c6 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh
index 052748ca6..c4ce91984 100755
--- a/egs/csj/ASR/prepare.sh
+++ b/egs/csj/ASR/prepare.sh
@@ -35,6 +35,9 @@
 # can generate other transcript formats by supplying your own config files. A few examples of these
 # config files can be found in local/conf.
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=8
diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh
index fd2532741..bd255dc6a 100755
--- a/egs/gigaspeech/ASR/prepare.sh
+++ b/egs/gigaspeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 94e003036..8668af0e4 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh
index 70586785d..91c3c667a 100755
--- a/egs/ptb/LM/prepare.sh
+++ b/egs/ptb/LM/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh
index 231ebd742..4842f52d0 100755
--- a/egs/spgispeech/ASR/prepare.sh
+++ b/egs/spgispeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=20
diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh
index 340521ad8..d9938fa63 100755
--- a/egs/tal_csasr/ASR/prepare.sh
+++ b/egs/tal_csasr/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index ccb307a52..272cf7aed 100755
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index d11cd3a05..148a9f51b 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 num_phones=39
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index da7d7e061..50a00253d 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
index 8fcee0290..d4ef8d601 100755
--- a/egs/yesno/ASR/prepare.sh
+++ b/egs/yesno/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1

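(Editor's note: the same two-line guard is prepended to every prepare.sh above. A quick sketch to confirm it took effect — assumes the google.protobuf package is installed; the variable must be set before protobuf is imported, just as prepare.sh exports it before running any Python stage.)

    import os

    # Mirror what prepare.sh exports; must happen before importing protobuf.
    os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")

    from google.protobuf.internal import api_implementation

    # Prints "python" when the pure-Python implementation is active,
    # i.e. the workaround for k2-fsa/icefall#674 is in effect.
    print(api_implementation.Type())
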
From 4c636c2cfffd853a4dc1f618dd8a6fede78a3bea Mon Sep 17 00:00:00 2001
From: Senyan Li <1149593720@qq.com>
Date: Fri, 25 Nov 2022 14:39:56 +0800
Subject: [PATCH 044/120] fix librispeech ASR pruned_transducer_stateless5
 export (#704)

---
 egs/librispeech/ASR/pruned_transducer_stateless5/export.py      | 2 ++
 egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py       | 1 +
 .../ASR/pruned_transducer_stateless5/scaling_converter.py       | 1 +
 3 files changed, 4 insertions(+)
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
index a4fad1e59..54f656859 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
@@ -50,6 +50,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -263,6 +264,7 @@ def main():
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptabe.
+        convert_scaled_to_non_scaled(model, inplace=True)
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file

From 89c3982a0760f135740556ae67c11d0af434303c Mon Sep 17 00:00:00 2001
From: Guo Liyong 
Date: Sat, 26 Nov 2022 00:50:21 +0800
Subject: [PATCH 045/120] show dominant parameters

---
 .../ASR/pruned_transducer_stateless7/optim.py | 79 ++++++++++++++++---
 .../ASR/pruned_transducer_stateless7/train.py | 13 ++-
 2 files changed, 79 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8b90c9a0d..ab55381d7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
         super(BatchedOptimizer, self).__init__(params, defaults)
 
     @contextlib.contextmanager
-    def batched_params(self, param_group):
+    def batched_params(self, param_group, group_params_names=None):
         """
         This function returns (technically, yields) a list of
           tuples (p, state), where
@@ -75,20 +75,28 @@ class BatchedOptimizer(Optimizer):
         batches = defaultdict(
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+        batches_names = defaultdict(
+            list
+        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
 
-        for p in param_group:
+        for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
+            batches_names[key].append(named_p)
+
+        batches_names_keys = list(batches_names.keys())
+        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
+        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
 
         stacked_params_dict = dict()
 
         # turn batches into a list, in deterministic order.
-        batches = [batches[key] for key in sorted(batches.keys())]
         # pairs will contain pairs of (stacked_param, state), one for each batch
         # in `batches`.
         pairs = []
 
-        for batch in batches:
+        for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
             # we arbitrarily store the state in the
             # state corresponding to the 1st parameter in the
@@ -100,11 +108,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state))
+            pairs.append((p_stacked, state, batch_names))
 
         yield pairs  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
@@ -165,6 +173,8 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
+        parameters_names=None,
+        show_dominant_parameters=False,
     ):
 
         defaults = dict(
@@ -181,6 +191,8 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         super(ScaledAdam, self).__init__(params, defaults)
+        self.parameters_names = parameters_names
+        self.show_dominant_parameters = show_dominant_parameters
 
     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
@@ -199,9 +211,11 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        for group in self.param_groups:
+        assert len(self.param_groups) == len(self.parameters_names)
 
-            with self.batched_params(group["params"]) as batches:
+        for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+            with self.batched_params(group["params"], group_params_names) as batches:
 
                 # batches is list of pairs (stacked_param, state).  stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
@@ -214,7 +228,7 @@ class ScaledAdam(BatchedOptimizer):
                 else:
                     clipping_scale = self._get_clipping_scale(group, batches)
 
-                for p, state in batches:
+                for p, state, _ in batches:
                     # Perform optimization step.
                     # grad is not going to be None, we handled that when creating the batches.
                     grad = p.grad
@@ -276,7 +290,7 @@ class ScaledAdam(BatchedOptimizer):
         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict]]
+        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -289,7 +303,7 @@ class ScaledAdam(BatchedOptimizer):
         """
         assert len(pairs) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state) = pairs[0]
+        (first_p, first_state, _) = pairs[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -298,7 +312,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state) in pairs:
+        for (p, state, param_names) in pairs:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -361,8 +375,49 @@ class ScaledAdam(BatchedOptimizer):
                 logging.warn(
                     f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                 )
+                if self.show_dominant_parameters:
+                    assert p.shape[0] == len(param_names)
+                    self._show_gradient_dominating_parameter(pairs, tot_sumsq)
             return ans
 
+    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
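+        """Log which parameter dominates the total gradient sumsq.
+
+        `pairs` is the list of (stacked_param, state, param_names) tuples
+        yielded by batched_params(); `tot_sumsq` is the total sum of squared
+        (rms-scaled) gradients accumulated in _get_clipping_scale().
+        """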
+        # "ori" means calculated with state["param_rms"].
+        # "cur" means calculated with the rms of the current parameter value.
+        # "bt" is short for "batch".
+        all_sumsq_ori = {}
+        all_sumsq_cur = {}
+        for (p, state, batch_param_names) in pairs:
+            # p is a stacked batch parameters.
+            grad = p.grad
+            if p.numel() == p.shape[0]:  # a batch of scalars
+                batch_sumsq_ori = grad**2  # for scalars, each grad entry is already a sumsq
+                batch_sumsq_cur = batch_sumsq_ori
+                # Dummy values used by the following `zip` statement.
+                batch_rms_ori = torch.zeros(p.shape[0])
+                batch_rms_cur = batch_rms_ori
+            else:
+                batch_rms_ori = state["param_rms"]
+                batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim)))
+
+                batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+                batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim)))
+
+            for name, sumsq_ori, sumsq_cur in zip(
+                    batch_param_names, batch_sumsq_ori, batch_sumsq_cur):
+
+                proportion_ori = sumsq_ori / tot_sumsq
+                proportion_cur = sumsq_cur / tot_sumsq
+
+                all_sumsq_ori[name] = (proportion_ori, sumsq_ori)
+                all_sumsq_cur[name] = (proportion_cur, sumsq_cur)
+
+        for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)):
+            sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)}
+            dominant_param_name = next(iter(sorted_by_proportion))
+            dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name]
+            logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion}  {dominant_sumsq} {tot_sumsq}")
+
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
     ):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b27c573ab..8375b1a18 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -368,6 +368,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--show-dominant-parameters",
+        type=str2bool,
+        default=False,
+        help="Whether to show dominant parameters.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -988,7 +995,11 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
+            clipping_scale=2.0, parameters_names=parameters_names,
+            show_dominant_parameters=params.show_dominant_parameters)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

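The `--show-dominant-parameters` diagnostic above batches parameters by (dtype, shape), tracks each parameter's name, and, whenever gradients are clipped, logs which parameter contributes the largest share of the total squared-gradient norm. Below is a minimal standalone sketch of that idea on a plain PyTorch model; the model, loss, and names here are made up for illustration and are independent of ScaledAdam:

```python
import torch
import torch.nn as nn

# Toy model and loss; any module with gradients would do.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
x, y = torch.randn(4, 8), torch.randn(4, 2)
nn.functional.mse_loss(model(x), y).backward()

# Per-parameter sum of squared gradients and each one's share of the total,
# mirroring what _show_gradient_dominating_parameter() logs.
sumsq = {
    name: (p.grad ** 2).sum().item()
    for name, p in model.named_parameters()
    if p.grad is not None
}
tot_sumsq = sum(sumsq.values())
dominant_name, dominant_sumsq = max(sumsq.items(), key=lambda kv: kv[1])
print(
    f"Dominant sumsq: {dominant_name} "
    f"proportion={dominant_sumsq / tot_sumsq:.3f} total={tot_sumsq:.3e}"
)
```

In the recipe itself the check is enabled with `--show-dominant-parameters True`; `train.py` builds `parameters_names` (one list of names per parameter group) from `model.named_parameters()` and passes it to `ScaledAdam`, as shown in the hunk above.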
From db75627e92155c16fd6e74d640ece4f6563f96f2 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Fri, 25 Nov 2022 21:00:45 -0500
Subject: [PATCH 046/120] [recipe] AMI Zipformer transducer (#698)

* remove unnecessary changes

* add AMI prepare scripts

* add zipformer scripts for AMI

* added logs and pretrained model

* minor fix

* remove unwanted changes

* fix missing link

* make suggested changes

* update results
---
 egs/ami/ASR/README.md                         |   48 +
 egs/ami/ASR/RESULTS.md                        |   92 ++
 egs/ami/ASR/local/__init__.py                 |    0
 egs/ami/ASR/local/compute_fbank_ami.py        |  194 +++
 egs/ami/ASR/local/compute_fbank_musan.py      |  114 ++
 egs/ami/ASR/local/prepare_ami_enhanced.py     |  158 +++
 egs/ami/ASR/local/prepare_ami_gss.sh          |   98 ++
 egs/ami/ASR/local/prepare_lang_bpe.py         |    1 +
 egs/ami/ASR/local/train_bpe_model.py          |    1 +
 egs/ami/ASR/prepare.sh                        |  144 ++
 .../pruned_transducer_stateless7/__init__.py  |    0
 .../asr_datamodule.py                         |  430 ++++++
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless7/decode.py    |  747 +++++++++++
 .../pruned_transducer_stateless7/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless7/export.py    |    1 +
 .../pruned_transducer_stateless7/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless7/model.py |    1 +
 .../ASR/pruned_transducer_stateless7/optim.py |    1 +
 .../pruned_transducer_stateless7/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../ASR/pruned_transducer_stateless7/train.py | 1184 +++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |    1 +
 egs/ami/ASR/shared                            |    1 +
 25 files changed, 3222 insertions(+)
 create mode 100644 egs/ami/ASR/README.md
 create mode 100644 egs/ami/ASR/RESULTS.md
 create mode 100644 egs/ami/ASR/local/__init__.py
 create mode 100755 egs/ami/ASR/local/compute_fbank_ami.py
 create mode 100755 egs/ami/ASR/local/compute_fbank_musan.py
 create mode 100644 egs/ami/ASR/local/prepare_ami_enhanced.py
 create mode 100755 egs/ami/ASR/local/prepare_ami_gss.sh
 create mode 120000 egs/ami/ASR/local/prepare_lang_bpe.py
 create mode 120000 egs/ami/ASR/local/train_bpe_model.py
 create mode 100755 egs/ami/ASR/prepare.sh
 create mode 100644 egs/ami/ASR/pruned_transducer_stateless7/__init__.py
 create mode 100644 egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
 create mode 100755 egs/ami/ASR/pruned_transducer_stateless7/decode.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/decoder.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/export.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/joiner.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/model.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/optim.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
 create mode 100755 egs/ami/ASR/pruned_transducer_stateless7/train.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/ami/ASR/shared

diff --git a/egs/ami/ASR/README.md b/egs/ami/ASR/README.md
new file mode 100644
index 000000000..1c9714bd4
--- /dev/null
+++ b/egs/ami/ASR/README.md
@@ -0,0 +1,48 @@
+# AMI
+
+This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's
+headset and lapel microphones, and also 2 array microphones containing 8 channels each.
+We pool data in the following 4 ways and train a single model on the pooled data:
+
+(i) individual headset microphone (IHM)
+(ii) IHM with simulated reverb
+(iii) Single distant microphone (SDM)
+(iv) GSS-enhanced array microphones
+
+Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled
+data. Here are the statistics of the combined training data:
+
+```python
+>>> cuts_train.describe()
+Cuts count: 1222053
+Total duration (hh:mm:ss): 905:00:28
+Speech duration (hh:mm:ss): 905:00:28 (99.9%)
+Duration statistics (seconds):
+mean    2.7
+std     2.8
+min     0.0
+25%     0.6
+50%     1.6
+75%     3.8
+99%     12.3
+99.5%   13.9
+99.9%   18.4
+max     36.8
+```
+
+**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement
+of far-field array microphones, but this is optional (see `prepare.sh` for details).
+
+## Performance Record
+
+### pruned_transducer_stateless7
+
+The following are decoded using `modified_beam_search`:
+
+| Evaluation set           | dev WER    | test WER |
+|--------------------------|------------|---------|
+| IHM                      |  18.92  | 17.40 |
+| SDM                      |  31.25  | 32.21 |
+| MDM (GSS-enhanced)       |  21.67  | 22.43 |
+
+See [RESULTS](/egs/ami/ASR/RESULTS.md) for details.
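For readers who want to inspect the pooled training data directly, here is a minimal lhotse sketch (not part of the recipe; the manifest paths are assumptions based on `prepare.sh` and `local/compute_fbank_ami.py`) that combines the four training sources and reproduces the duration statistics shown above:

```python
from lhotse import CutSet, load_manifest_lazy

parts = ["ihm", "ihm_rvb", "sdm", "gss"]
cut_sets = [
    load_manifest_lazy(f"data/manifests/cuts_train_{part}.jsonl.gz")
    for part in parts
]

# Lazily interleave the pooled sources into a single training CutSet.
cuts_train = CutSet.mux(*cut_sets)

# Prints cut count and duration statistics like the block above.
cuts_train.describe()
```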
diff --git a/egs/ami/ASR/RESULTS.md b/egs/ami/ASR/RESULTS.md
new file mode 100644
index 000000000..163986021
--- /dev/null
+++ b/egs/ami/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### AMI training results (Pruned Transducer)
+
+#### 2022-11-20
+
+#### Zipformer (pruned_transducer_stateless7)
+
+Zipformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform tensor dim).
+
+All the results below are using a single model that is trained by combining the following
+data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
+augmentation are applied on top of the pooled data.
+
+**WERs for IHM:**
+
+|                           | dev | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  19.25  |  17.83  | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search      |  18.92  |  17.40  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  19.44  |  18.04  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for SDM:**
+
+|                           | dev | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  31.32  |  32.38  | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search      |  31.25  |  32.21  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  31.11  |  32.10  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for GSS-enhanced MDM:**
+
+|                           | dev | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  22.05  |  22.93  | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search      |  21.67  |  22.43  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  22.21  |  22.83  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 150 \
+  --max-cuts 150 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25 \
+  --use-fp16 True
+```
+
+The decoding command is:
+```
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 14 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method greedy_search
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+```
+
+Pretrained model is available at 
+
+The tensorboard training log can be found at
+
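The decoder described at the top of this file (embedding layer, Conv1d with kernel size 2, linear projection) can be pictured with the minimal sketch below. It is illustrative only; the dimensions are assumptions and this is not the recipe's actual `decoder.py`:

```python
import torch
import torch.nn as nn


class StatelessDecoderSketch(nn.Module):
    """Embedding -> Conv1d(kernel_size=2) -> Linear, as described above."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, out_dim: int = 512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=2)
        self.linear = nn.Linear(embed_dim, out_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token ids; the kernel-2 conv shortens the sequence by one.
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embed_dim, U)
        out = self.conv(emb).permute(0, 2, 1)     # (N, U - 1, embed_dim)
        return self.linear(out)


decoder = StatelessDecoderSketch(vocab_size=500)
tokens = torch.randint(0, 500, (2, 6))
print(decoder(tokens).shape)  # torch.Size([2, 5, 512])
```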
diff --git a/egs/ami/ASR/local/__init__.py b/egs/ami/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py
new file mode 100755
index 000000000..4892b40e3
--- /dev/null
+++ b/egs/ami/ASR/local/compute_fbank_ami.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins University        (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AMI dataset.
+For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced
+audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
+parts (which are the 3 evaluation settings).
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+import math
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    KaldifeatFrameOptions,
+    KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_ami():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    sampling_rate = 16000
+    num_mel_bins = 80
+
+    extractor = KaldifeatFbank(
+        KaldifeatFbankConfig(
+            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+            device="cuda",
+        )
+    )
+
+    logging.info("Reading manifests")
+    manifests_ihm = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=src_dir,
+        prefix="ami-ihm",
+        suffix="jsonl.gz",
+    )
+    manifests_sdm = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=src_dir,
+        prefix="ami-sdm",
+        suffix="jsonl.gz",
+    )
+    # For GSS we already have cuts so we read them directly.
+    manifests_gss = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=src_dir,
+        prefix="ami-gss",
+        suffix="jsonl.gz",
+    )
+
+    def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
+        cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
+        _ = cuts.compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=storage_path,
+            manifest_path=manifest_path,
+            batch_duration=5000,
+            num_workers=8,
+            storage_type=LilcomChunkyWriter,
+        )
+
+    logging.info(
+        "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
+    )
+
+    logging.info("Processing train split IHM")
+    cuts_ihm = (
+        CutSet.from_manifests(**manifests_ihm["train"])
+        .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+        .modify_ids(lambda x: x + "-ihm")
+    )
+    _extract_feats(
+        cuts_ihm,
+        output_dir / "feats_train_ihm",
+        src_dir / "cuts_train_ihm.jsonl.gz",
+    )
+
+    logging.info("Processing train split IHM + reverberated IHM")
+    cuts_ihm_rvb = cuts_ihm.reverb_rir()
+    _extract_feats(
+        cuts_ihm_rvb,
+        output_dir / "feats_train_ihm_rvb",
+        src_dir / "cuts_train_ihm_rvb.jsonl.gz",
+    )
+
+    logging.info("Processing train split SDM")
+    cuts_sdm = (
+        CutSet.from_manifests(**manifests_sdm["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-sdm")
+    )
+    _extract_feats(
+        cuts_sdm,
+        output_dir / "feats_train_sdm",
+        src_dir / "cuts_train_sdm.jsonl.gz",
+    )
+
+    logging.info("Processing train split GSS")
+    cuts_gss = (
+        CutSet.from_manifests(**manifests_gss["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-gss")
+    )
+    _extract_feats(
+        cuts_gss,
+        output_dir / "feats_train_gss",
+        src_dir / "cuts_train_gss.jsonl.gz",
+    )
+
+    logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
+    for split in ["dev", "test"]:
+        logging.info(f"Processing {split} IHM")
+        cuts_ihm = (
+            CutSet.from_manifests(**manifests_ihm[split])
+            .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_ihm",
+                manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
+                batch_duration=5000,
+                num_workers=8,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} SDM")
+        cuts_sdm = (
+            CutSet.from_manifests(**manifests_sdm[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_sdm",
+                manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} GSS")
+        cuts_gss = (
+            CutSet.from_manifests(**manifests_gss[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_gss",
+                manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_ami()
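A quick way to confirm that the script above produced usable manifests is to load one lazily and check that features were attached; the path below is an assumption based on the `manifest_path` arguments in the script:

```python
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/manifests/cuts_train_ihm.jsonl.gz")
first = next(iter(cuts))
print(first.id, f"{first.duration:.2f}s", "has_features:", first.has_features)
```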
diff --git a/egs/ami/ASR/local/compute_fbank_musan.py b/egs/ami/ASR/local/compute_fbank_musan.py
new file mode 100755
index 000000000..1fcf951f9
--- /dev/null
+++ b/egs/ami/ASR/local/compute_fbank_musan.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, LilcomChunkyWriter, combine
+from lhotse.features.kaldifeat import (
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    KaldifeatFrameOptions,
+    KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    sampling_rate = 16000
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "music",
+        "speech",
+        "noise",
+    )
+    prefix = "musan"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
+
+    if musan_cuts_path.is_file():
+        logging.info(f"{musan_cuts_path} already exists - skipping")
+        return
+
+    logging.info("Extracting features for Musan")
+
+    extractor = KaldifeatFbank(
+        KaldifeatFbankConfig(
+            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+            device="cuda",
+        )
+    )
+
+    # create chunks of Musan with duration 5 - 10 seconds
+    _ = (
+        CutSet.from_manifests(
+            recordings=combine(part["recordings"] for part in manifests.values())
+        )
+        .cut_into_windows(10.0)
+        .filter(lambda c: c.duration > 5)
+        .compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=output_dir / "musan_feats",
+            manifest_path=musan_cuts_path,
+            batch_duration=500,
+            num_workers=4,
+            storage_type=LilcomChunkyWriter,
+        )
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    compute_fbank_musan()
diff --git a/egs/ami/ASR/local/prepare_ami_enhanced.py b/egs/ami/ASR/local/prepare_ami_enhanced.py
new file mode 100644
index 000000000..bed220eb3
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_ami_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Data preparation for AMI GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
+    parser.add_argument(
+        "manifests_dir",
+        type=Path,
+        help="Path to directory containing AMI manifests.",
+    )
+    parser.add_argument(
+        "enhanced_dir",
+        type=Path,
+        help="Path to enhanced data directory.",
+    )
+    parser.add_argument(
+        "--num-jobs",
+        "-j",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to run.",
+    )
+    parser.add_argument(
+        "--min-segment-duration",
+        "-d",
+        type=float,
+        default=0.0,
+        help="Minimum duration of a segment in seconds.",
+    )
+    return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+    """
+    Given a supervision (corresponding to original AMI recording), this function finds the
+    enhanced recording corresponding to the supervision, and returns this recording and
+    a new supervision whose start and end times are adjusted to match the enhanced recording.
+    """
+    file_name = Path(
+        f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
+    )
+    save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
+    if save_path.exists():
+        recording = Recording.from_file(save_path)
+        if recording.duration == 0:
+            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+            return None
+
+        # The old supervision is w.r.t. the original recording; we create a new
+        # supervision w.r.t. the enhanced segment.
+        new_supervision = fastcopy(
+            supervision,
+            recording_id=recording.id,
+            start=0,
+            duration=recording.duration,
+        )
+        return recording, new_supervision
+    else:
+        logging.warning(f"{save_path} does not exist.")
+        return None
+
+
+def main(args):
+    # Get arguments
+    manifests_dir = args.manifests_dir
+    enhanced_dir = args.enhanced_dir
+
+    # Load manifests from cache if they exist (saves time)
+    manifests = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=manifests_dir,
+        prefix="ami-sdm",
+        suffix="jsonl.gz",
+    )
+    if not manifests:
+        raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir))
+
+    with ThreadPoolExecutor(args.num_jobs) as ex:
+        for part in ["train", "dev", "test"]:
+            logging.info(f"Processing {part}...")
+            supervisions_orig = manifests[part]["supervisions"].filter(
+                lambda s: s.duration >= args.min_segment_duration
+            )
+            # Remove TS3009d supervisions since they are not present in the enhanced data
+            supervisions_orig = supervisions_orig.filter(
+                lambda s: s.recording_id != "TS3009d"
+            )
+            futures = []
+
+            for supervision in tqdm(
+                supervisions_orig,
+                desc="Distributing tasks",
+            ):
+                futures.append(
+                    ex.submit(
+                        find_recording_and_create_new_supervision,
+                        enhanced_dir,
+                        supervision,
+                    )
+                )
+
+            recordings = []
+            supervisions = []
+            for future in tqdm(
+                futures,
+                total=len(futures),
+                desc="Processing tasks",
+            ):
+                result = future.result()
+                if result is not None:
+                    recording, new_supervision = result
+                    recordings.append(recording)
+                    supervisions.append(new_supervision)
+
+            # Remove duplicates from the recordings
+            recordings_nodup = {}
+            for recording in recordings:
+                if recording.id not in recordings_nodup:
+                    recordings_nodup[recording.id] = recording
+                else:
+                    logging.warning("Recording {} is duplicated.".format(recording.id))
+            recordings = RecordingSet.from_recordings(recordings_nodup.values())
+            supervisions = SupervisionSet.from_segments(supervisions)
+
+            recordings, supervisions = fix_manifests(
+                recordings=recordings, supervisions=supervisions
+            )
+
+            logging.info(f"Writing {part} enhanced manifests")
+            recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz")
+            supervisions.to_file(
+                manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz"
+            )
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh
new file mode 100755
index 000000000..d5422458b
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_ami_gss.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+# This script is used to run GSS-based enhancement on AMI data.
+set -euo pipefail
+nj=4
+stage=0
+
+. shared/parse_options.sh || exit 1
+
+if [ $# != 2 ]; then
+   echo "Wrong #arguments ($#, expected 2)"
+   echo "Usage: local/prepare_ami_gss.sh [options]  "
+   echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss"
+   echo "main options (for others, see top of script file)"
+   echo "  --nj                                 # number of parallel jobs"
+   echo "  --stage                           # stage to start running from"
+   exit 1;
+fi
+
+DATA_DIR=$1
+EXP_DIR=$2
+
+mkdir -p $EXP_DIR
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 1 ]; then
+  log "Stage 1: Prepare cut sets"
+  for part in train dev test; do
+    lhotse cut simple \
+      -r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \
+      -s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \
+      $EXP_DIR/cuts_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 2 ]; then
+  log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
+  for part in train dev test; do
+    lhotse cut trim-to-supervisions --discard-overlapping \
+        $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 3 ]; then
+  log "Stage 3: Split manifests for multi-GPU processing (optional)"
+  for part in train; do
+    gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj
+  done
+fi
+
+if [ $stage -le 4 ]; then
+  log "Stage 4: Enhance train segments using GSS (requires GPU)"
+  # for train, we use smaller context and larger batches to speed-up processing
+  for JOB in $(seq $nj); do
+    gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 5.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 35.0 \
+      --max-batch-duration 60.0 \
+      --num-buckets 3 \
+      --num-workers 2
+  done
+fi
+
+if [ $stage -le 5 ]; then
+  log "Stage 5: Enhance dev/test segments using GSS (using GPU)"
+  # for dev/test, we use larger context and smaller batches to get better quality
+  for part in dev test; do
+    for JOB in $(seq $nj); do
+      gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
+      $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 15.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 30.0 \
+      --max-batch-duration 45.0 \
+      --num-buckets 3 \
+      --num-workers 2
+    done
+  done
+fi
+
+if [ $stage -le 6 ]; then
+  log "Stage 6: Prepare manifests for GSS-enhanced data"
+  python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
+fi
diff --git a/egs/ami/ASR/local/prepare_lang_bpe.py b/egs/ami/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/ami/ASR/local/train_bpe_model.py b/egs/ami/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/ami/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/ami/ASR/prepare.sh b/egs/ami/ASR/prepare.sh
new file mode 100755
index 000000000..fb21a8ec6
--- /dev/null
+++ b/egs/ami/ASR/prepare.sh
@@ -0,0 +1,144 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+use_gss=true  # Use GSS-based enhancement with MDM setting
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/amicorpus
+#      You can find audio and transcripts in this path.
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+#
+#  - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19}
+#      These contain the Fisher English audio and transcripts. We will
+#      only use the transcripts as extra LM training data (similar to Kaldi).
+#
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+vocab_size=500
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/amicorpus,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/amicorpus $dl_dir/amicorpus
+  #
+  if [ ! -d $dl_dir/amicorpus ]; then
+    lhotse download ami --mic ihm $dl_dir/amicorpus
+    lhotse download ami --mic mdm $dl_dir/amicorpus
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare AMI manifests"
+  # We assume that you have downloaded the AMI corpus
+  # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
+  mkdir -p data/manifests
+  for mic in ihm sdm mdm; do
+    lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \
+      --max-words-per-segment 30 $dl_dir/amicorpus data/manifests/
+  done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
+  log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+  # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+  local/prepare_ami_gss.sh data/manifests exp/ami_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank features for AMI"
+  mkdir -p data/fbank
+  python local/compute_fbank_ami.py
+  log "Combine features from train splits"
+  lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+    gzip -c > data/manifests/cuts_train_all.jsonl.gz
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank features for musan"
+  mkdir -p data/fbank
+  python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Dump transcripts for BPE model training."
+  mkdir -p data/lm
+  cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare BPE based lang"
+
+  lang_dir=data/lang_bpe_${vocab_size}
+  mkdir -p $lang_dir
+
+  # Add special words to words.txt
+  echo " 0" > $lang_dir/words.txt
+  echo "!SIL 1" >> $lang_dir/words.txt
+  echo " 2" >> $lang_dir/words.txt
+
+  # Add regular words to words.txt
+  cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
+
+  # Add remaining special word symbols expected by LM scripts.
+  num_words=$(cat $lang_dir/words.txt | wc -l)
+  echo " ${num_words}" >> $lang_dir/words.txt
+  num_words=$(cat $lang_dir/words.txt | wc -l)
+  echo " ${num_words}" >> $lang_dir/words.txt
+  num_words=$(cat $lang_dir/words.txt | wc -l)
+  echo "#0 ${num_words}" >> $lang_dir/words.txt
+
+  ./local/train_bpe_model.py \
+    --lang-dir $lang_dir \
+    --vocab-size $vocab_size \
+    --transcript data/lm/transcript_words.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+  fi
+fi
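After Stage 7 finishes, the resulting `words.txt` can be sanity-checked with a short Python sketch. The path assumes `vocab_size=500` as set above, and the special-symbol layout (reserved IDs 0-2 at the top, LM symbols appended at the end) follows the usual icefall convention rather than anything specific to this patch:

```python
from pathlib import Path


def summarize_words_txt(path: str = "data/lang_bpe_500/words.txt") -> None:
    # Each non-empty line is "<symbol> <integer id>".
    pairs = [
        line.split() for line in Path(path).read_text().splitlines() if line.strip()
    ]
    ids = [int(p[1]) for p in pairs]
    assert ids == sorted(ids), "word IDs are expected to be increasing"
    print("first entries:", pairs[:3])   # reserved special symbols (IDs 0-2)
    print("last entries: ", pairs[-3:])  # symbols appended for the LM scripts
    print("total symbols:", len(pairs))


if __name__ == "__main__":
    summarize_words_txt()
```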
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/__init__.py b/egs/ami/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..f7ee9c962
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,430 @@
+# Copyright      2021  Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class AmiAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/manifests"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=100.0,
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
+        )
+        group.add_argument(
+            "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=50,
+            help=(
+                "The number of buckets for the BucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled (=default), the examples will be "
+                "shuffled for each epoch."
+            ),
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=8,
+            help=(
+                "The number of training dataloader workers that " "collect the batches."
+            ),
+        )
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
+        )
+        group.add_argument(
+            "--ihm-only",
+            type=str2bool,
+            default=False,
+            help="When enabled, only use IHM data for training.",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        logging.info("About to get Musan cuts")
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                "Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=2,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        if self.args.on_the_fly_feats:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+            )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+            )
+
+        logging.info("Using DynamicBucketingSampler.")
+        train_sampler = DynamicBucketingSampler(
+            cuts_train,
+            max_duration=self.args.max_duration,
+            max_cuts=self.args.max_cuts,
+            shuffle=False,
+            num_buckets=self.args.num_buckets,
+            drop_last=True,
+        )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=True,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    def remove_short_cuts(self, cut: Cut) -> bool:
+        """
+        See: https://github.com/k2-fsa/icefall/issues/500
+        Basically, the zipformer model subsamples the input using the following formula:
+        num_out_frames = (num_in_frames - 7)//2
+        For num_out_frames to be at least 1, num_in_frames must be at least 9.
+        """
+        return cut.duration >= 0.09
+
+    @lru_cache()
+    def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
+        logging.info("About to get AMI train cuts")
+
+        def _remove_short_and_long_utt(c: Cut):
+            if c.duration < 0.2 or c.duration > 25.0:
+                return False
+
+            # In pruned RNN-T, we require that T >= S
+            # where T is the number of feature frames after subsampling
+            # and S is the number of tokens in the utterance
+
+            # In ./zipformer.py, the conv module uses the following expression
+            # for subsampling
+            T = ((c.num_frames - 7) // 2 + 1) // 2
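+            # For example, a 10 s utterance has roughly 1000 frames (at a
+            # 10 ms frame shift), so T = ((1000 - 7) // 2 + 1) // 2 = 248 and
+            # the cut is kept as long as its transcript has at most 248 tokens.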
+            tokens = sp.encode(c.supervisions[0].text, out_type=str)
+            return T >= len(tokens)
+
+        if self.args.ihm_only:
+            cuts_train = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
+            )
+        else:
+            cuts_train = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_train_all.jsonl.gz"
+            )
+
+        return cuts_train.filter(_remove_short_and_long_utt)
+
+    @lru_cache()
+    def dev_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AMI IHM dev cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def dev_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AMI SDM dev cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def dev_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
+            logging.info("No GSS dev cuts found")
+            return None
+        logging.info("About to get AMI GSS-enhanced dev cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AMI IHM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AMI SDM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
+            logging.info("No GSS test cuts found")
+            return None
+        logging.info("About to get AMI GSS-enhanced test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..37516affc
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..f47228fbe
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,747 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 100 \
+        --decoding-method greedy_search
+
+(2) beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method beam_search \
+        --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AmiAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import NgramLm
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=10,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An interger indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+      word_table:
+        The word symbol table.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = model.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
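+        # With the default arguments, for fast_beam_search_nbest_LG this key
+        # would be, e.g.,
+        # "beam_4.0_max_contexts_8_max_states_64_num_paths_200_nbest_scale_0.5_ngram_lm_scale_0.01"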
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    test_set_cers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        wers_filename = (
+            params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(wers_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        # We also compute CER for the AMI dataset.
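+        # Each result is converted to character sequences by joining the words,
+        # e.g. (hypothetical) ["HELLO", "WORLD"] -> list("HELLOWORLD")
+        #      == ['H', 'E', 'L', 'L', 'O', 'W', 'O', 'R', 'L', 'D'].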
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        cers_filename = (
+            params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(cers_filename, "w") as f:
+            cer = write_error_stats(
+                f, f"{test_set_name}-{key}", results_char, enable_log=True
+            )
+            test_set_cers[key] = cer
+
+        logging.info("Wrote detailed error stats to {}".format(wers_filename))
+
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER\tCER", file=f)
+        for key in test_set_wers:
+            print(
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
+                file=f,
+            )
+
+    s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key in test_set_wers:
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AmiAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest_LG",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(f"{params.lang_dir}/bpe.model")
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    ami = AmiAsrDataModule(args)
+
+    dev_ihm_cuts = ami.dev_ihm_cuts()
+    test_ihm_cuts = ami.test_ihm_cuts()
+    dev_sdm_cuts = ami.dev_sdm_cuts()
+    test_sdm_cuts = ami.test_sdm_cuts()
+    dev_gss_cuts = ami.dev_gss_cuts()
+    test_gss_cuts = ami.test_gss_cuts()
+
+    dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts)
+    test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
+    dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
+    test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
+    if dev_gss_cuts is not None:
+        dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
+    if test_gss_cuts is not None:
+        test_gss_dl = ami.test_dataloaders(test_gss_cuts)
+
+    test_sets = {
+        "dev_ihm": (dev_ihm_dl, dev_ihm_cuts),
+        "test_ihm": (test_ihm_dl, test_ihm_cuts),
+        "dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
+        "test_sdm": (test_sdm_dl, test_sdm_cuts),
+    }
+    if dev_gss_cuts is not None:
+        test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
+    if test_gss_cuts is not None:
+        test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+    for test_set in test_sets:
+        logging.info(f"Decoding {test_set}")
+        dl, cuts = test_sets[test_set]
+        results_dict = decode_dataset(
+            dl=dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/export.py b/egs/ami/ASR/pruned_transducer_stateless7/export.py
new file mode 120000
index 000000000..2713792e6
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/export.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/model.py b/egs/ami/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/optim.py b/egs/ami/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..b5efb3405
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1184 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 150 \
+  --use-fp16 True
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AmiAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=11,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=5000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
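+    # For instance, with the default --average-period 200, at
+    # batch_idx_train = 1000 the update described above becomes
+    #   model_avg = model * (200 / 1000) + model_avg * (800 / 1000)
+    #             = 0.2 * model + 0.8 * model_avg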
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to write statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 100,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
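+    # e.g. to_int_tuple("2,4,3,2,4") == (2, 4, 3, 2, 4)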
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      sp:
+        The BPE model used to encode the transcripts.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"]
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = supervisions["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
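+        # For example, with the default --simple-loss-scale 0.5 and
+        # warm_step = 2000, at batch_idx_train = 1000 these become
+        #   simple_loss_scale = 1.0 - (1000 / 2000) * (1.0 - 0.5) = 0.75
+        #   pruned_loss_scale = 0.1 + 0.9 * (1000 / 2000) = 0.55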
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = ((feature_lens - 7) // 2).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; step_batch() is called after every batch.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have
+            # different behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
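The running statistic `tot_loss = tot_loss * (1 - 1 / params.reset_interval) + loss_info` in train_one_epoch is an exponentially decayed running sum whose effective window is roughly `reset_interval` batches. Here is a scalar sketch of that update; `running_sum` is a made-up name and the default of 200 is an assumption, the real value comes from get_params(), which is outside this hunk.

```python
def running_sum(values, reset_interval: int = 200):
    """Exponentially decayed running sum, mirroring how tot_loss is updated."""
    decay = 1.0 - 1.0 / reset_interval
    tot = 0.0
    for v in values:
        tot = tot * decay + v
    return tot


# For a constant per-batch value v the sum converges to v / (1 - decay),
# i.e. v * reset_interval, so tot_loss behaves like a sum over roughly the
# last `reset_interval` batches, and a ratio such as
# tot_loss["loss"] / tot_loss["frames"] like an average over that window.
print(round(running_sum([1.0] * 2000)))  # ~200
```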
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    ami = AmiAsrDataModule(args)
+
+    # Here is the duration statistics of the training set.
+    # Cuts count: 1230033
+    # Total duration (hh:mm:ss): 904:25:34
+    # Speech duration (hh:mm:ss): 904:25:34 (100.0%)
+    # Duration statistics (seconds):
+    # mean	2.6
+    # std	2.8
+    # min	0.0
+    # 25%	0.6
+    # 50%	1.6
+    # 75%	3.8
+    # 99%	12.3
+    # 99.5%	13.9
+    # 99.9%	18.3
+    # max	36.8
+
+    train_cuts = ami.train_cuts(sp=sp)
+    train_dl = ami.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+    valid_cuts = ami.dev_ihm_cuts()
+    valid_dl = ami.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    AmiAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/ami/ASR/shared b/egs/ami/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/ami/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file

From 61032e70e097aea63d191183466d0f1b16f9e16e Mon Sep 17 00:00:00 2001
From: abb128 <65567823+abb128@users.noreply.github.com>
Date: Sat, 26 Nov 2022 04:10:37 +0200
Subject: [PATCH 047/120] Fix exception in find_checkpoints (#668)

---
 icefall/checkpoint.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 8aa0a8eeb..f0663a1df 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,7 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
+    iter_checkpoints = []
+    for c in checkpoints:
+        result = pattern.search(c)
+        if not result:
+            logging.warn(f"Invalid checkpoint filename {c}")
+            continue
+        
+        iter_checkpoints.append((int(result.group(1)), c))
+
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 

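The fix above guards the filename regex against files that match the glob but not the `checkpoint-<n>.pt` pattern, which previously raised an exception when `pattern.search(c)` returned None. A standalone sketch of the same parsing strategy; `parse_checkpoint_iterations` and the file names are made up for illustration (the dot is escaped here, while the original pattern leaves it unescaped, which also works for these names).

```python
import re


def parse_checkpoint_iterations(filenames):
    """Return (iteration, filename) pairs, skipping names the regex rejects."""
    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
    parsed = []
    for name in filenames:
        match = pattern.search(name)
        if not match:
            continue  # e.g. a stray "checkpoint-best.pt" is skipped instead of crashing
        parsed.append((int(match.group(1)), name))
    return sorted(parsed, reverse=True)


print(parse_checkpoint_iterations(
    ["exp/checkpoint-4000.pt", "exp/checkpoint-8000.pt", "exp/checkpoint-best.pt"]
))
# [(8000, 'exp/checkpoint-8000.pt'), (4000, 'exp/checkpoint-4000.pt')]
```
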
From 9cf79cac3f23757e25b499947f045efb0f4d71a6 Mon Sep 17 00:00:00 2001
From: Guo Liyong 
Date: Sat, 26 Nov 2022 21:48:17 +0800
Subject: [PATCH 048/120] message formatting

---
 .../ASR/pruned_transducer_stateless7/optim.py | 76 +++++++++++--------
 .../ASR/pruned_transducer_stateless7/train.py | 10 +--
 2 files changed, 45 insertions(+), 41 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index ab55381d7..790752fe1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
         super(BatchedOptimizer, self).__init__(params, defaults)
 
     @contextlib.contextmanager
-    def batched_params(self, param_group, group_params_names=None):
+    def batched_params(self, param_group, group_params_names):
         """
         This function returns (technically, yields) a list of
           of tuples (p, state), where
@@ -85,7 +85,9 @@ class BatchedOptimizer(Optimizer):
             batches_names[key].append(named_p)
 
         batches_names_keys = list(batches_names.keys())
-        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
+        sorted_idx = sorted(
+            range(len(batches_names)), key=lambda i: batches_names_keys[i]
+        )
         batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
         batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
 
@@ -174,7 +176,7 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
         parameters_names=None,
-        show_dominant_parameters=False,
+        show_dominant_parameters=True,
     ):
 
         defaults = dict(
@@ -211,7 +213,7 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        assert len(self.param_groups)  == len(self.parameters_names)
+        assert len(self.param_groups) == len(self.parameters_names)
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
 
@@ -381,42 +383,52 @@ class ScaledAdam(BatchedOptimizer):
             return ans
 
     def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
-        # ori means calculated with state["param_rms"]
-        # cur means calculated with "param_rms" of current param.
-        # bt is short batch
-        # all_sumsq_ori_rms
-        all_sumsq_ori = {}
-        all_sumsq_cur = {}
+        all_sumsq_orig = {}
         for (p, state, batch_param_names) in pairs:
             # p is a stacked batch parameters.
-            grad = p.grad
+            batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
-                batch_sumsq_ori = grad**2  # sum() to change shape [1] to []
-                batch_sumsq_cur = batch_sumsq_ori  # sum() to change shape [1] to []
+                batch_sumsq_orig = batch_grad**2
                 # Dummy values used by following `zip` statement.
-                batch_rms_ori = torch.zeros(p.shape[0])
-                batch_rms_cur = batch_rms_ori
+                batch_rms_orig = torch.ones(p.shape[0])
             else:
-                batch_rms_ori = state["param_rms"]
-                batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim)))
+                batch_rms_orig = state["param_rms"]
+                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+                    dim=list(range(1, batch_grad.ndim))
+                )
 
-                batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
-                batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim)))
+            for name, sumsq_orig, rms, grad in zip(
+                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+            ):
 
-            for name, sumsq_ori, sumsq_cur in zip(
-                    batch_param_names, batch_sumsq_ori, batch_sumsq_cur):
+                proportion_orig = sumsq_orig / tot_sumsq
+                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
 
-                proportion_ori = sumsq_ori / tot_sumsq
-                proportion_cur = sumsq_cur / tot_sumsq
-
-                all_sumsq_ori[name] = (proportion_ori, sumsq_ori)
-                all_sumsq_cur[name] = (proportion_cur, sumsq_cur)
-
-        for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)):
-            sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)}
-            dominant_param_name = next(iter(sorted_by_proportion))
-            dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name]
-            logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion}  {dominant_sumsq} {tot_sumsq}")
+        assert torch.isclose(
+            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
+            torch.tensor(1.0),
+        )
+        sorted_by_proportion = {
+            k: v
+            for k, v in sorted(
+                all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+            )
+        }
+        dominant_param_name = next(iter(sorted_by_proportion))
+        (
+            dominant_proportion,
+            dominant_sumsq,
+            dominant_rms,
+            dominant_grad,
+        ) = sorted_by_proportion[dominant_param_name]
+        logging.info(
+            f"Parameter dominating tot_sumsq {dominant_param_name}"
+            f" with proportion {dominant_proportion:.2f},"
+            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+            f"={dominant_sumsq:.3e},"
+            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+        )
 
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8375b1a18..e5a3e68df 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -368,13 +368,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--show-dominant-parameters",
-        type=str2bool,
-        default=False,
-        help="Whether to show dominant parameters.",
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -998,8 +991,7 @@ def run(rank, world_size, args):
     parameters_names = []
     parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
     optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-            clipping_scale=2.0, parameters_names=parameters_names,
-            show_dominant_parameters=params.show_dominant_parameters)
+            clipping_scale=2.0, parameters_names=parameters_names)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

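The rewritten diagnostic above sorts per-parameter contributions by their share of tot_sumsq and reports only the top entry. A toy sketch of that selection step; `dominant_entry` and the numbers are made up for illustration, standing in for the `{name: sumsq / tot_sumsq}` proportions computed in the optimizer.

```python
def dominant_entry(proportions):
    """Return (name, proportion) of the entry with the largest proportion."""
    name = max(proportions, key=proportions.get)
    return name, proportions[name]


# Made-up proportions that sum to 1.0.
proportions = {
    "encoder.layer0.weight": 0.62,
    "decoder.embed.weight": 0.30,
    "joiner.out.bias": 0.08,
}
print(dominant_entry(proportions))  # ('encoder.layer0.weight', 0.62)
```
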
From 6693d907d3ddd5c5eade144b55a57c8831d6d9b2 Mon Sep 17 00:00:00 2001
From: huangruizhe 
Date: Sat, 26 Nov 2022 22:26:09 -0500
Subject: [PATCH 049/120] shuffle full Librispeech data (#574)

* shuffled full/partial librispeech data

* fixed the code style issue

* Shuffled full librispeech data off-line

* Fixed style, addressed comments, and removed redundant code

* Used the suggested version of black

* Propagated the changes to other folders for librispeech (except
conformer_mmi and streaming_conformer_ctc)
---
 egs/librispeech/ASR/conformer_ctc/train.py             |  6 +++---
 egs/librispeech/ASR/conformer_ctc2/train.py            |  6 +++---
 .../ASR/conv_emformer_transducer_stateless/train.py    |  6 +++---
 .../ASR/conv_emformer_transducer_stateless2/train.py   |  6 +++---
 egs/librispeech/ASR/lstm_transducer_stateless/train.py |  6 +++---
 .../ASR/lstm_transducer_stateless2/train.py            |  6 +++---
 egs/librispeech/ASR/prepare.sh                         |  5 +++++
 .../ASR/pruned_stateless_emformer_rnnt2/train.py       |  6 +++---
 .../ASR/pruned_transducer_stateless/train.py           |  6 +++---
 .../ASR/pruned_transducer_stateless2/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless3/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless4/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless5/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless6/train.py          |  6 +++---
 egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py    | 10 ++++++++++
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py             |  8 ++++----
 egs/librispeech/ASR/transducer/train.py                |  6 +++---
 egs/librispeech/ASR/transducer_lstm/train.py           |  6 +++---
 egs/librispeech/ASR/transducer_stateless/train.py      |  6 +++---
 egs/librispeech/ASR/transducer_stateless2/train.py     |  6 +++---
 .../ASR/transducer_stateless_multi_datasets/train.py   |  6 +++---
 21 files changed, 73 insertions(+), 58 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 1449bc310..99fe64793 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -687,10 +687,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index ceea0c22c..121fdb256 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -928,10 +928,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index 213115854..6bb5505aa 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -970,10 +970,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 6a019fd63..8462ae92a 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -970,10 +970,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index a54108f6d..feb81d500 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -954,10 +954,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 8736384b4..4fc4fa7f8 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -1108,10 +1108,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts, sp)
 
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 8668af0e4..542bbcdd8 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -123,6 +123,11 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     touch data/fbank/.librispeech.done
   fi
 
+  cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \
+    <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \
+    <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \
+    shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz
+
   if [ ! -e data/fbank/.librispeech-validated.done ]; then
     log "Validating data/fbank for LibriSpeech"
     parts=(
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
index ed3fa1521..3601e1e11 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
@@ -882,10 +882,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 4dabbccc1..cf4032027 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -873,10 +873,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 86333fc97..6c19f2cb0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -931,10 +931,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 281ba4650..fdafa5a87 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -1065,10 +1065,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts, sp)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index cb56c8294..9bd7df401 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -978,10 +978,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index 436620744..847c80ab0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -1009,10 +1009,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index 8f4d3b879..57753599a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -970,10 +970,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 95d1b273a..c5787835d 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -414,6 +414,16 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
         )
 
+    @lru_cache()
+    def train_all_shuf_cuts(self) -> CutSet:
+        logging.info(
+            "About to get the shuffled train-clean-100, "
+            "train-clean-360 and train-other-500 cuts"
+        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 071ac792b..0aa1587ba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -173,7 +173,7 @@ def get_params() -> AttributeDict:
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
             "lang_dir": Path("data/lang_phone"),
-            "lr": 1e-3,
+            "lr": 1e-4,
             "feature_dim": 80,
             "weight_decay": 5e-4,
             "subsampling_factor": 3,
@@ -557,10 +557,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 674ea10a6..29625754e 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -614,10 +614,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 57bda63fd..792708bc0 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -620,10 +620,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index bcb883fa5..8db9b59e7 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -641,10 +641,10 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 68e247f23..1c3a33870 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -629,10 +629,10 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 88987d91c..dafccd088 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -752,10 +752,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts)
 

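The new stage in prepare.sh concatenates the three training manifests, shuffles the lines, and writes librispeech_cuts_train-all-shuf.jsonl.gz, which train_all_shuf_cuts() then loads lazily. Below is a rough Python equivalent of that shell pipeline, not code from the patch, assuming gzipped JSONL manifests with one cut per line and the paths used in the recipe.

```python
import gzip
import random

parts = [
    "data/fbank/librispeech_cuts_train-clean-100.jsonl.gz",
    "data/fbank/librispeech_cuts_train-clean-360.jsonl.gz",
    "data/fbank/librispeech_cuts_train-other-500.jsonl.gz",
]

lines = []
for path in parts:
    with gzip.open(path, "rt") as f:
        lines.extend(f)  # one JSON-serialized cut per line

random.shuffle(lines)  # mirrors `shuf`; note this holds all lines in memory

with gzip.open("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz", "wt") as f:
    f.writelines(lines)
```

Shuffling once offline keeps the training dataloader free to read the manifest lazily while still seeing the three subsets interleaved, instead of concatenating them in order at training time.
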
From 4fee3e7f1ea6c2aefe7594e325ede1e530e54d3d Mon Sep 17 00:00:00 2001
From: Guo Liyong 
Date: Mon, 28 Nov 2022 16:55:18 +0800
Subject: [PATCH 050/120] improve comment

---
 .../ASR/pruned_transducer_stateless7/optim.py | 63 +++++++++++++------
 .../ASR/pruned_transducer_stateless7/train.py | 12 +++-
 2 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 790752fe1..ff8fbb32c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -64,13 +64,15 @@ class BatchedOptimizer(Optimizer):
         you can do:
         
           with self.batched_params(group["params"]) as batches:
-             for p, state in batches:
+             for p, state, p_names in batches:
                  ...
         
 
         Args:
           group: a parameter group, which is a list of parameters; should be
-                one of self.groups.
+                one of self.param_groups.
+          group_params_names: name for each parameter in group,
+                which is List[str].
         """
         batches = defaultdict(
             list
@@ -79,6 +81,7 @@ class BatchedOptimizer(Optimizer):
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
 
+        assert len(param_group) == len(group_params_names)
         for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
@@ -94,9 +97,9 @@ class BatchedOptimizer(Optimizer):
         stacked_params_dict = dict()
 
         # turn batches into a list, in deterministic order.
-        # pairs will contain pairs of (stacked_param, state), one for each batch
-        # in `batches`.
-        pairs = []
+        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+        # one for each batch in `batches`.
+        tuples = []
 
         for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
@@ -110,11 +113,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state, batch_names))
+            tuples.append((p_stacked, state, batch_names))
 
-        yield pairs  # <-- calling code will do the actual optimization here!
+        yield tuples  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
@@ -179,6 +182,11 @@ class ScaledAdam(BatchedOptimizer):
         show_dominant_parameters=True,
     ):
 
+        assert parameters_names is not None, (
+            "Please prepare parameters_names, "
+            "which is a List[List[str]]. Each List[str] is for a group "
+            "and each str is for a parameter"
+        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -193,6 +201,7 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         super(ScaledAdam, self).__init__(params, defaults)
+        assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
         self.show_dominant_parameters = show_dominant_parameters
 
@@ -213,7 +222,6 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        assert len(self.param_groups) == len(self.parameters_names)
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
 
@@ -292,7 +300,7 @@ class ScaledAdam(BatchedOptimizer):
         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
+        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -300,12 +308,16 @@ class ScaledAdam(BatchedOptimizer):
 
         Args:
            group: the parameter group, an item in self.param_groups
-           pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
-                (1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
+           tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str], where each str is the name of a
+                parameter in the batched set of parameters "param".
         """
-        assert len(pairs) >= 1
+        assert len(tuples) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state, _) = pairs[0]
+        (first_p, first_state, _) = tuples[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -314,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in pairs:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -379,12 +391,27 @@ class ScaledAdam(BatchedOptimizer):
                 )
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(pairs, tot_sumsq)
+                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
             return ans
 
-    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
+    def _show_gradient_dominating_parameter(
+        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+    ):
+        """
+        Show information about the parameter that dominates tot_sumsq.
+
+        Args:
+           tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str], where each str is the name of a
+                parameter in the batched set of parameters "param".
+            tot_sumsq: sumsq of all parameters. Though it could be calculated
+                from tuples, we pass it in to save some time.
+        """
         all_sumsq_orig = {}
-        for (p, state, batch_param_names) in pairs:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index e5a3e68df..31a3a0505 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -989,9 +989,15 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     parameters_names = []
-    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-            clipping_scale=2.0, parameters_names=parameters_names)
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

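batched_params groups parameters by a (dtype, *shape) key so that same-shaped tensors can be stacked and updated as one batch, and train.py now passes parameters_names (a List[List[str]], one list per parameter group) so each stacked batch keeps track of which parameters it contains. A minimal sketch of the grouping idea only, using plain tensors and made-up parameter names rather than the optimizer classes themselves:

```python
from collections import defaultdict

import torch

params = {
    "encoder.w1": torch.zeros(3, 4),
    "encoder.w2": torch.zeros(3, 4),
    "decoder.bias": torch.zeros(7),
}

batches = defaultdict(list)        # (dtype, *shape) -> list of tensors
batches_names = defaultdict(list)  # (dtype, *shape) -> list of names
for name, p in params.items():
    key = (str(p.dtype), *p.shape)
    batches[key].append(p)
    batches_names[key].append(name)

for key in sorted(batches):
    stacked = torch.stack(batches[key])  # one update can now act on the whole batch
    print(key, batches_names[key], tuple(stacked.shape))
# ('torch.float32', 3, 4) ['encoder.w1', 'encoder.w2'] (2, 3, 4)
# ('torch.float32', 7) ['decoder.bias'] (1, 7)
```
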
From ece728d895c11545eb3232caa4f6a1c907192064 Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Mon, 28 Nov 2022 22:34:02 +0800
Subject: [PATCH 051/120] Apply delay penalty on k2 ctc loss (#669)

* add init files

* fix bug, apply delay penalty

* fix decoding code and getting timestamps

* add option applying delay penalty on ctc log-prob

* fix bug of streaming decoding

* minor change for bpe-based case

* add test_model.py

* add README.md

* add CI
---
 .flake8                                       |    2 +-
 ...n-librispeech-conformer-ctc3-2022-11-28.sh |  119 ++
 ...-librispeech-conformer-ctc3-2022-11-28.yml |  151 +++
 egs/librispeech/ASR/RESULTS.md                |  102 +-
 .../ASR/conformer_ctc3/__init__.py            |    1 +
 .../ASR/conformer_ctc3/asr_datamodule.py      |    1 +
 .../ASR/conformer_ctc3/conformer.py           |    1 +
 egs/librispeech/ASR/conformer_ctc3/decode.py  | 1004 +++++++++++++++
 .../ASR/conformer_ctc3/encoder_interface.py   |    1 +
 egs/librispeech/ASR/conformer_ctc3/export.py  |  292 +++++
 .../ASR/conformer_ctc3/jit_pretrained.py      |  406 ++++++
 egs/librispeech/ASR/conformer_ctc3/lstmp.py   |    1 +
 egs/librispeech/ASR/conformer_ctc3/model.py   |  122 ++
 egs/librispeech/ASR/conformer_ctc3/optim.py   |    1 +
 .../ASR/conformer_ctc3/pretrained.py          |  458 +++++++
 egs/librispeech/ASR/conformer_ctc3/scaling.py |    1 +
 .../ASR/conformer_ctc3/scaling_converter.py   |    1 +
 .../ASR/conformer_ctc3/test_model.py          |   82 ++
 egs/librispeech/ASR/conformer_ctc3/train.py   | 1108 +++++++++++++++++
 icefall/bpe_graph_compiler.py                 |    5 +-
 icefall/char_graph_compiler.py                |    3 +-
 icefall/checkpoint.py                         |    2 +-
 icefall/graph_compiler.py                     |    4 +
 icefall/utils.py                              |   51 +-
 24 files changed, 3876 insertions(+), 43 deletions(-)
 create mode 100755 .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
 create mode 100644 .github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/__init__.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/conformer.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/decode.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/export.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/lstmp.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc3/model.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/optim.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/pretrained.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/scaling.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/test_model.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/train.py

diff --git a/.flake8 b/.flake8
index 609fa2c03..a0f44263c 100644
--- a/.flake8
+++ b/.flake8
@@ -11,7 +11,7 @@ per-file-ignores =
     egs/*/ASR/*/scaling.py: E501,
     egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
     egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
-    egs/librispeech/ASR/conformer_ctc2/*py: E501,
+    egs/librispeech/ASR/conformer_ctc*/*py: E501,
     egs/librispeech/ASR/RESULTS.md: E999,
 
     # invalid escape sequence (cause by tex formular), W605
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
new file mode 100755
index 000000000..27944807f
--- /dev/null
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -0,0 +1,119 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/*"
+git lfs pull --include "exp/jit_trace.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Decode with models exported by torch.jit.trace()"
+
+for m in ctc-decoding 1best; do
+  ./conformer_ctc3/jit_pretrained.py \
+    --model-filename $repo/exp/jit_trace.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+log "Export to torchscript model"
+
+./conformer_ctc3/export.py \
+  --exp-dir $repo/exp \
+  --lang-dir $repo/data/lang_bpe_500 \
+  --jit-trace 1 \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.trace()"
+
+for m in ctc-decoding 1best; do
+  ./conformer_ctc3/jit_pretrained.py \
+    --model-filename $repo/exp/jit_trace.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+  ./conformer_ctc3/pretrained.py \
+    --checkpoint $repo/exp/pretrained.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode"  ]]; then
+  mkdir -p conformer_ctc3/exp
+  ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh conformer_ctc3/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in ctc-decoding 1best; do
+    log "Decoding with $method"
+    ./conformer_ctc3/decode.py \
+      --epoch 999 \
+      --avg 1 \
+      --use-averaged-model 0 \
+      --exp-dir conformer_ctc3/exp/ \
+      --max-duration $max_duration \
+      --decoding-method $method \
+      --lm-dir data/lm
+  done
+
+  rm conformer_ctc3/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
new file mode 100644
index 000000000..21f396c32
--- /dev/null
+++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
@@ -0,0 +1,151 @@
+# Copyright      2022  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-conformer-ctc3-2022-11-28
+# conformer_ctc3
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_2022_11_28_conformer_ctc3:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+
+      - name: Display decoding results for librispeech conformer_ctc3
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./conformer_ctc3/exp
+
+          cd conformer_ctc3
+          echo "results for conformer_ctc3"
+          echo "===ctc-decoding==="
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===1best==="
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for librispeech conformer_ctc3
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
+          path: egs/librispeech/ASR/conformer_ctc3/exp/
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 030e47b86..efd60ba81 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,106 @@
 ## Results
 
+### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty)
+
+#### [conformer_ctc3](./conformer_ctc3)
+
+It implements Conformer model training with CTC loss.
+In streaming mode, it supports a symbol delay penalty.
+
+See  for more details.
+
+##### training on full librispeech
+
+This model contains 12 encoder layers. The number of model parameters is 77352694.
+
+The WERs are:
+
+|                                     | test-clean | test-other | comment              |
+|-------------------------------------|------------|------------|----------------------|
+| ctc-decoding                        | 3.09       | 7.62       | --epoch 25 --avg 7   |
+| 1best                               | 2.87       | 6.44       | --epoch 25 --avg 7   |
+| nbest                               | 2.88       | 6.5        | --epoch 25 --avg 7   |
+| nbest-rescoring                     | 2.71       | 6.1        | --epoch 25 --avg 7   |
+| whole-lattice-rescoring             | 2.71       | 6.04       | --epoch 25 --avg 7   |
+
+The training command is:
+
+```bash
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 25 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/full \
+  --full-libri 1 \
+  --max-duration 300 \
+  --master-port 12345
+```
+
+The tensorboard log can be found at
+
+
+The decoding command using different methods is:
+```bash
+for method in ctc-decoding 1best nbest nbest-rescoring whole-lattice-rescoring; do
+  ./conformer_ctc3/decode.py \
+    --epoch 25 \
+    --avg 7 \
+    --exp-dir conformer_ctc3/exp \
+    --max-duration 300 \
+    --decoding-method $method \
+    --manifest-dir data/fbank \
+    --lm-dir data/lm
+done
+```
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+The command to train a streaming model with symbol delay penalty is:
+```bash
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --dynamic-chunk-training 1 \
+  --causal-convolution 1 \
+  --short-chunk-size 25 \
+  --num-left-chunks 4 \
+  --max-duration 300 \
+  --delay-penalty 0.1
+```
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+```bash
+./local/add_alignment_librispeech.py \
+  --alignments-dir data/alignment \
+  --cuts-in-dir data/fbank \
+  --cuts-out-dir data/fbank_ali
+```
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+```bash
+./conformer_ctc3/decode.py \
+  --epoch 25 \
+  --avg 7 \
+  --exp-dir ./conformer_ctc3/exp \
+  --max-duration 300 \
+  --decoding-method ctc-decoding \
+  --simulate-streaming 1 \
+  --causal-convolution 1 \
+  --decode-chunk-size 16 \
+  --left-context 64 \
+  --manifest-dir data/fbank_ali
+```
+Note: Symbol delay can be computed with the following decoding methods (a sketch of the metric follows this list):
+  - ctc-greedy-search
+  - ctc-decoding
+  - 1best
+
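+The reported symbol delay compares, for each recognized word, the predicted
+timestamp with the aligned reference timestamp. Below is a minimal sketch of
+that metric, assuming both sides provide word-start times in seconds;
+`mean_var_delay` is a hypothetical helper shown only for illustration, not the
+actual implementation used by `write_error_stats_with_timestamps`:
+```python
+from typing import List, Tuple
+
+
+def mean_var_delay(
+    hyp_times: List[List[float]], ref_times: List[List[float]]
+) -> Tuple[float, float]:
+    # Positive delays mean the symbol is emitted later than the reference.
+    delays = [
+        h - r
+        for hs, rs in zip(hyp_times, ref_times)
+        for h, r in zip(hs, rs)
+    ]
+    mean = sum(delays) / len(delays)
+    var = sum((d - mean) ** 2 for d in delays) / len(delays)
+    return mean, var
+
+
+# Two words of one utterance: mean delay ~0.06 s, variance ~0.0004.
+print(mean_var_delay([[0.52, 1.10]], [[0.48, 1.02]]))
+```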
+
 ### pruned_transducer_stateless8 (zipformer + multidataset)
 
 See  for more details.
@@ -115,7 +216,6 @@ done
 ```
 
 
-
 ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
 
 #### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
diff --git a/egs/librispeech/ASR/conformer_ctc3/__init__.py b/egs/librispeech/ASR/conformer_ctc3/__init__.py
new file mode 120000
index 000000000..b24e5e357
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/__init__.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/__init__.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/conformer.py b/egs/librispeech/ASR/conformer_ctc3/conformer.py
new file mode 120000
index 000000000..3b84b9573
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/conformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/conformer.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
new file mode 100755
index 000000000..8eca2ae02
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -0,0 +1,1004 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) decode in non-streaming mode (take ctc-decoding as an example)
+./conformer_ctc3/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./conformer_ctc3/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding
+
+(2) decode in streaming mode (take ctc-decoding as an example)
+./conformer_ctc3/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --simulate-streaming 1 \
+    --causal-convolution 1 \
+    --decode-chunk-size 16 \
+    --left-context 64 \
+    --exp-dir ./conformer_ctc3/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding
+
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+./conformer_ctc3/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./conformer_ctc3/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding \
+    --simulate-streaming 1 \
+    --causal-convolution 1 \
+    --decode-chunk-size 16 \
+    --left-context 64 \
+    --manifest-dir data/fbank_ali
+Note: It supports calculating symbol delay with the following decoding methods:
+    - ctc-greedy-search
+    - ctc-decoding
+    - 1best
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    DecodingResults,
+    get_texts,
+    get_texts_with_timestamp,
+    make_pad_mask,
+    parse_hyp_and_timestamp,
+    setup_logger,
+    store_transcripts_and_timestamps,
+    str2bool,
+    write_error_stats_with_timestamps,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless4/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="ctc-decoding",
+        help="""Decoding method.
+        Supported values are:
+        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (1) ctc-greedy-search. It uses only the CTC output and a sentence
+          piece model for decoding. It produces the same results as
+          ctc-decoding.
+        - (2) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (3) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+          the highest score is the decoding result.
+        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+          n-gram LM (e.g., a 4-gram LM); the best path of the rescored
+          lattice is the decoding result.
+        - (6) nbest-oracle. Its WER is the lower bound of what any n-best
+          rescoring method can achieve. Useful for debugging n-best
+          rescoring methods.
+        """,
+    )
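+
+    # A rough summary of what each method needs (derived from the help text
+    # above and from main() below):
+    #   - ctc-decoding / ctc-greedy-search: only lang_dir/bpe.model
+    #   - 1best / nbest / nbest-oracle: lang_dir/HLG.pt
+    #   - nbest-rescoring / whole-lattice-rescoring: lang_dir/HLG.pt plus
+    #     lm_dir/G_4_gram.pt (or G_4_gram.fst.txt)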
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The n-gram LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    nnet_output_lens: torch.Tensor,
+) -> Tuple[List[List[int]], List[List[int]], torch.Tensor]:
+    """Apply CTC greedy search.
+    Args:
+      ctc_probs (torch.Tensor): (batch, max_len, num_classes)
+      nnet_output_lens (torch.Tensor): (batch, )
+    Returns:
+      A tuple of (token ID sequences, frame indices of those tokens, scores).
+    """
+    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
+    topk_index = topk_index.squeeze(2)  # (B, maxlen)
+    mask = make_pad_mask(nnet_output_lens)
+    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
+    hyps = [hyp.tolist() for hyp in topk_index]
+    scores = topk_prob.squeeze(2).max(dim=1).values  # (batch,); unused by callers
+    ret_hyps = []
+    timestamps = []
+    for i in range(len(hyps)):
+        hyp, time = remove_duplicates_and_blank(hyps[i])
+        ret_hyps.append(hyp)
+        timestamps.append(time)
+    return ret_hyps, timestamps, scores
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]:
+    # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
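+    # Illustrative example (not executed): for hyp = [0, 7, 7, 0, 0, 9, 9, 9, 0]
+    # this returns ([7, 9], [1, 5]); blanks (id 0) are dropped, consecutive
+    # repeats are collapsed, and each kept token records the frame index of
+    # its first occurrence.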
+    new_hyp: List[int] = []
+    time: List[int] = []
+    cur = 0
+    while cur < len(hyp):
+        if hyp[cur] != 0:
+            new_hyp.append(hyp[cur])
+            time.append(cur)
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            cur += 1
+    return new_hyp, time
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
+
+    nnet_output = model.get_ctc_output(encoder_out)
+    # nnet_output is (N, T, C)
+
+    if params.decoding_method == "ctc-greedy-search":
+        hyps, timestamps, _ = ctc_greedy_search(
+            nnet_output,
+            encoder_out_lens,
+        )
+        res = DecodingResults(hyps=hyps, timestamps=timestamps)
+        hyps, timestamps = parse_hyp_and_timestamp(
+            res=res,
+            sp=bpe_model,
+            subsampling_factor=params.subsampling_factor,
+            frame_shift_ms=params.frame_shift_ms,
+        )
+        key = "ctc-greedy-search"
+        return {key: (hyps, timestamps)}
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    if params.decoding_method in ["1best", "nbest", "nbest-oracle"]:
+        hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0]
+
+        ori_scores = decoding_graph.scores.clone()
+
+        ans = {}
+        for hlg_scale in hlg_scale_list:
+            decoding_graph.scores = ori_scores * hlg_scale
+            lattice = get_lattice(
+                nnet_output=nnet_output,
+                decoding_graph=decoding_graph,
+                supervision_segments=supervision_segments,
+                search_beam=params.search_beam,
+                output_beam=params.output_beam,
+                min_active_states=params.min_active_states,
+                max_active_states=params.max_active_states,
+                subsampling_factor=params.subsampling_factor,
+            )
+            key_suffix = f"-HLG-scale-{hlg_scale}"
+
+            if params.decoding_method == "nbest-oracle":
+                # Note: You can also pass rescored lattices to it.
+                # We choose the HLG decoded lattice for speed reasons
+                # as HLG decoding is faster and the oracle WER
+                # is only slightly worse than that of rescored lattices.
+                best_path = nbest_oracle(
+                    lattice=lattice,
+                    num_paths=params.num_paths,
+                    ref_texts=supervisions["text"],
+                    word_table=word_table,
+                    nbest_scale=params.nbest_scale,
+                    oov="<UNK>",
+                )
+                hyps = get_texts(best_path)
+                hyps = [[word_table[i] for i in ids] for ids in hyps]
+                key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}"  # noqa
+                timestamps = [[] for _ in range(len(hyps))]
+                ans[key + key_suffix] = (hyps, timestamps)
+
+            elif params.decoding_method in ["1best", "nbest"]:
+                if params.decoding_method == "1best":
+                    best_path = one_best_decoding(
+                        lattice=lattice,
+                        use_double_scores=params.use_double_scores,
+                    )
+                    key = "no-rescore"
+                    res = get_texts_with_timestamp(best_path)
+                    hyps, timestamps = parse_hyp_and_timestamp(
+                        res=res,
+                        subsampling_factor=params.subsampling_factor,
+                        frame_shift_ms=params.frame_shift_ms,
+                        word_table=word_table,
+                    )
+                else:
+                    best_path = nbest_decoding(
+                        lattice=lattice,
+                        num_paths=params.num_paths,
+                        use_double_scores=params.use_double_scores,
+                        nbest_scale=params.nbest_scale,
+                    )
+                    key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+                    hyps = get_texts(best_path)
+                    hyps = [[word_table[i] for i in ids] for ids in hyps]
+                    timestamps = [[] for _ in range(len(hyps))]
+
+                ans[key + key_suffix] = (hyps, timestamps)
+
+        return ans
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.decoding_method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of IDs
+        res = get_texts_with_timestamp(best_path)
+        hyps, timestamps = parse_hyp_and_timestamp(
+            res=res,
+            sp=bpe_model,
+            subsampling_factor=params.subsampling_factor,
+            frame_shift_ms=params.frame_shift_ms,
+        )
+        key = "ctc-decoding"
+        return {key: (hyps, timestamps)}
+
+    assert params.decoding_method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.decoding_method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.decoding_method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.decoding_method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            timestamps = [[] for _ in range(len(hyps))]
+            ans[lm_scale_str] = (hyps, timestamps)
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains five elements:
+      the cut ID, the reference transcript, the decoding result, the
+      reference timestamps, and the predicted timestamps.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[
+        str,
+        List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    ],
+):
+    test_set_wers = dict()
+    test_set_delays = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "ctc-greedy-search",
+        "ctc-decoding",
+        "1best",
+        "nbest",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "nbest-oracle",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    params.vocab_size = num_classes
+    params.sos_id = sos_id
+    params.eos_id = eos_id
+
+    if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.decoding_method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.decoding_method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
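+            # Note (descriptive comment): average_checkpoints_with_averaged_model
+            # only needs these two checkpoints; conceptually the result is
+            #   avg ~= (avg_end * N_end - avg_start * N_start) / (N_end - N_start)
+            # where avg_* are the running averages stored in the checkpoints
+            # and N_* count the training batches seen so far.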
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py
new file mode 100755
index 000000000..c5b95d981
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/export.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+
+(1) Export to torchscript model using torch.jit.trace()
+
+./conformer_ctc3/export.py \
+  --exp-dir ./conformer_ctc3/exp \
+  --lang-dir data/lang_bpe_500 \
+  --epoch 20 \
+  --avg 10 \
+  --jit-trace 1
+
+It will generate the file `jit_trace.pt`.
+
+(2) Export `model.state_dict()`
+
+./conformer_ctc3/export.py \
+  --exp-dir ./conformer_ctc3/exp \
+  --lang-dir data/lang_bpe_500 \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `conformer_ctc3/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./conformer_ctc3/decode.py \
+        --exp-dir ./conformer_ctc3/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 100 \
+        --lang-dir data/lang_bpe_500
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless4/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--jit-trace",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model, if the models in exp-dir
+        are streaming model, this should be True.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    params.vocab_size = num_classes
+
+    if params.streaming_model:
+        assert params.causal_convolution
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit_trace:
+        # TODO: will support streaming mode
+        assert not params.streaming_model
+        convert_scaled_to_non_scaled(model, inplace=True)
+
+        logging.info("Using torch.jit.trace()")
+
+        x = torch.zeros(1, 100, 80, dtype=torch.float32)
+        x_lens = torch.tensor([100], dtype=torch.int64)
+        traced_model = torch.jit.trace(model, (x, x_lens))
+
+        filename = params.exp_dir / "jit_trace.pt"
+        traced_model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.trace()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
new file mode 100755
index 000000000..c96defd23
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo,)
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Usage (for non-streaming mode):
+
+(1) ctc-decoding
+./conformer_ctc3/jit_pretrained.py \
+  --model-filename conformer_ctc3/exp/jit_trace.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(2) 1best
+./conformer_ctc3/jit_pretrained.py \
+  --model-filename conformer_ctc3/exp/jit_trace.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(3) nbest-rescoring
+./conformer_ctc3/jit_pretrained.py \
+  --model-filename conformer_ctc3/exp/jit_trace.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(4) whole-lattice-rescoring
+./conformer_ctc3/jit_pretrained.py \
+  --model-filename conformer_ctc3/exp/jit_trace.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model.",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the size of the n-best list.
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, at the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+    params.vocab_size = params.num_classes
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(args.model_filename)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
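+    # Pad the log-mel features with log(1e-10), i.e. (near) silence, so that
+    # padded frames carry essentially no energy.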
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    nnet_output, _ = model(features, feature_lengths)
+
+    batch_size = nnet_output.shape[0]
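+    # Each row of supervision_segments is (sequence_index, start_frame,
+    # num_frames); every utterance here starts at frame 0 and spans the
+    # full (padded) output length.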
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
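+        # Build the standard (non-modified) CTC topology; token id 0 is
+        # assumed to be the blank.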
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/lstmp.py b/egs/librispeech/ASR/conformer_ctc3/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/model.py b/egs/librispeech/ASR/conformer_ctc3/model.py
new file mode 100644
index 000000000..f56df2006
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/model.py
@@ -0,0 +1,122 @@
+# Copyright  2021-2022  Xiaomi Corp.     (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
+
+
+class CTCModel(nn.Module):
+    """It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
+    "Connectionist Temporal Classification: Labelling Unsegmented
+    Sequence Data with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        encoder_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+          encoder_dim:
+            The feature embedding dimension.
+          vocab_size:
+            The vocabulary size.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder = encoder
+        self.ctc_output_module = nn.Sequential(
+            nn.Dropout(p=0.1),
+            ScaledLinear(encoder_dim, vocab_size),
+        )
+
+    def get_ctc_output(
+        self,
+        encoder_out: torch.Tensor,
+        delay_penalty: float = 0.0,
+        blank_threshold: float = 0.99,
+    ):
+        """Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty.
+        We first split utterance into sub-utterances according to the
+        blank probs, and then add sawtooth-like "blank-bonus" values to
+        the blank probs.
+        See https://github.com/k2-fsa/icefall/pull/669 for details.
+
+        Args:
+          encoder_out:
+            A tensor with shape of (N, T, C).
+          delay_penalty:
+            A constant used to scale the delay penalty score.
+          blank_threshold:
+            The threshold used to split utterance into sub-utterances.
+        """
+        output = self.ctc_output_module(encoder_out)
+        log_prob = nn.functional.log_softmax(output, dim=-1)
+
+        if self.training and delay_penalty > 0:
+            T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device)
+            # split into sub-utterances using the blank-id
+            mask = log_prob[:, :, 0] >= math.log(blank_threshold)  # (B, T)
+            mask[:, 0] = True
+            cummax_out = (T_arange * mask).cummax(dim=-1)[0]  # (B, T)
+            # the sawtooth "blank-bonus" value
+            penalty = T_arange - cummax_out  # (B, T)
+            penalty_all = torch.zeros_like(log_prob)
+            penalty_all[:, :, 0] = delay_penalty * penalty
+            # apply latency penalty on probs
+            log_prob = log_prob + penalty_all
+
+        return log_prob
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        warmup: float = 1.0,
+        delay_penalty: float = 0.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          warmup: a floating point value which increases throughout training;
+            values >= 1.0 are fully warmed up and have all modules present.
+          delay_penalty:
+            A constant used to scale the delay penalty score.
+        """
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
+        assert torch.all(encoder_out_lens > 0)
+        nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty)
+        return nnet_output, encoder_out_lens
diff --git a/egs/librispeech/ASR/conformer_ctc3/optim.py b/egs/librispeech/ASR/conformer_ctc3/optim.py
new file mode 120000
index 000000000..e2deb4492
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
new file mode 100755
index 000000000..3628d6a5f
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo,)
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Usage (for non-streaming mode):
+
+(1) ctc-decoding
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(2) 1best
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(3) nbest-rescoring
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(4) whole-lattice-rescoring
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the size of the n-best list.
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, at the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+    params.vocab_size = params.num_classes
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    # model forward
+    if params.simulate_streaming:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=features,
+            x_lens=feature_lengths,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=features, x_lens=feature_lengths
+        )
+    nnet_output = model.get_ctc_output(encoder_out)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling.py b/egs/librispeech/ASR/conformer_ctc3/scaling.py
new file mode 120000
index 000000000..09d802cc4
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py
new file mode 100755
index 000000000..b97b7eed8
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./conformer_ctc3/test_model.py
+"""
+
+import torch
+
+from train import get_params, get_ctc_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    params.dynamic_chunk_training = False
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = False
+
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    model(x=features, x_lens=feature_lengths)
+
+
+def test_model_streaming():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    params.dynamic_chunk_training = True
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = True
+
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths)
+    model.get_ctc_output(encoder_out)
+
+
+def main():
+    test_model()
+    test_model_streaming()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
new file mode 100755
index 000000000..fb3b740c1
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -0,0 +1,1108 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Mingshuang Luo,)
+#                                                  Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mixed precision training:
+
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+# train a streaming model
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --dynamic-chunk-training 1 \
+  --causal-convolution 1 \
+  --short-chunk-size 25 \
+  --num-left-chunks 4 \
+  --max-duration 300 \
+  --delay-penalty 0.0
+"""
+
+import argparse
+import copy
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import CTCModel
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training, if you want a streaming
+        model, this requires to be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution, this requires to be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length of dynamic training, the chunk size would be either
+        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How many left context can be seen in chunks when calculating attention.",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc3/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="""The initial learning rate. This value should not need to be
+        changed.""",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
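+    # A hypothetical numeric instance of the update above: with
+    # average_period=100 and batch_idx_train=400, the update is
+    # model_avg = 0.25 * model + 0.75 * model_avg, i.e. a running average
+    # of the parameters from the start of training is maintained.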
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant used to scale the symbol delay penalty,
+        to encourage symbols to be emitted earlier for streaming models.
+        It is almost the same as the `delay_penalty` in our `rnnt_loss`. See
+        https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
+    parser.add_argument(
+        "--nnet-delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant to penalize symbol delay, which is applied on
+        the nnet_output after log-softmax.
+        We recommend using --delay-penalty instead.
+        See https://github.com/k2-fsa/icefall/pull/669 for details.""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if `batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of layers in the transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "encoder_dim": 512,
+            "nhead": 8,
+            "dim_feedforward": 2048,
+            "num_encoder_layers": 12,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.causal_convolution,
+    )
+    return encoder
+
+
+def get_ctc_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    model = CTCModel(
+        encoder=encoder,
+        encoder_dim=params.encoder_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of CTCModel in our case.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_out_lens = model(
+            feature,
+            feature_lens,
+            warmup=warmup,
+            delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0,
+        )
+        assert torch.all(encoder_out_lens > 0)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
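+    # The delay penalty is applied only once the model is fully warmed up
+    # (warmup >= 1.0); during warm-up it is set to 0, presumably to keep the
+    # early stages of training stable.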
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
+        reduction=params.reduction,
+        use_double_scores=params.use_double_scores,
+    )
+    ctc_loss_is_finite = torch.isfinite(ctc_loss)
+    if not torch.all(ctc_loss_is_finite):
+        logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
+        ctc_loss = ctc_loss[ctc_loss_is_finite]
+
+        # If the ctc_loss is inf or nan for all utterances in the batch,
+        # we stop the training process by raising an exception
+        if torch.all(~ctc_loss_is_finite):
+            raise ValueError(
+                "There are too many utterances in this batch "
+                "leading to inf or nan losses."
+            )
+    loss = ctc_loss.sum()
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    # info["frames"] is an approximate number for two reasons:
+    # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+    # (2) If some utterances in the batch lead to inf/nan loss, they
+    #     are filtered out.
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+        # summary stats
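+        # tot_loss is an exponentially-decaying running sum: the old value is
+        # scaled by (1 - 1/reset_interval) before adding the new loss_info, so
+        # statistics older than roughly reset_interval batches fade out.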
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        scaler.scale(loss).backward()
+        scheduler.step_batch(params.batch_idx_train)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
+
+        if params.print_diagnostics and batch_idx == 30:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    params.vocab_size = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    if "lang_bpe" in str(params.lang_dir):
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in str(params.lang_dir):
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
+
+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        diagnostic = diagnostics.attach_diagnostics(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if params.start_batch <= 0 and not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
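+    # `batches` maps each sampling criterion (e.g. max duration) to the most
+    # demanding batch found for it, and `crit_values` holds the corresponding
+    # criterion values, so potential OOM cases are surfaced before training.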
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py
index e76b7ea32..d9659c2dd 100644
--- a/icefall/bpe_graph_compiler.py
+++ b/icefall/bpe_graph_compiler.py
@@ -83,11 +83,12 @@ class BpeCtcTrainingGraphCompiler(object):
         Args:
           piece_ids:
             It is a list-of-list integer IDs.
-         modified:
+          modified:
            See :func:`k2.ctc_graph` for its meaning.
         Return:
           Return an FsaVec, which is the result of composing a
           CTC topology with linear FSAs constructed from the given
           piece IDs.
         """
-        return k2.ctc_graph(piece_ids, modified=modified, device=self.device)
+        graph = k2.ctc_graph(piece_ids, modified=modified, device=self.device)
+        return graph
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index c31db6e4c..5f9571d42 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -117,4 +117,5 @@ class CharCtcTrainingGraphCompiler(object):
           CTC topology with linear FSAs constructed from the given
           piece IDs.
         """
-        return k2.ctc_graph(token_ids, modified=modified, device=self.device)
+        graph = k2.ctc_graph(token_ids, modified=modified, device=self.device)
+        return graph
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index f0663a1df..c83c56a53 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -298,7 +298,7 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
         if not result:
             logging.warn(f"Invalid checkpoint filename {c}")
             continue
-        
+
         iter_checkpoints.append((int(result.group(1)), c))
 
     # iter_checkpoints is a list of tuples. Each tuple contains
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index e2ff03f61..84be81254 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -79,6 +79,10 @@ class CtcTrainingGraphCompiler(object):
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
+        self.ctc_topo._is_repeat_token_ = (
+            self.ctc_topo.labels != self.ctc_topo.aux_labels
+        )
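+        # The arcs just marked (label != aux_label) are the repeat-token
+        # self-loops of the CTC topology; this flag is presumably consumed by
+        # the delay-penalized CTC loss in k2.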
+
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
         )
diff --git a/icefall/utils.py b/icefall/utils.py
index b4d8e9a51..d852491c8 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -670,8 +670,8 @@ def write_error_stats_with_timestamps(
     all_delay = []
     for cut_id, ref, hyp, time_ref, time_hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
-        has_time_ref = len(time_ref) > 0
-        if has_time_ref:
+        has_time = len(time_ref) > 0 and len(time_hyp) > 0
+        if has_time:
             # pointer to timestamp_hyp
             p_hyp = 0
             # pointer to timestamp_ref
@@ -680,28 +680,28 @@ def write_error_stats_with_timestamps(
             if ref_word == ERR:
                 ins[hyp_word] += 1
                 words[hyp_word][3] += 1
-                if has_time_ref:
+                if has_time:
                     p_hyp += 1
             elif hyp_word == ERR:
                 dels[ref_word] += 1
                 words[ref_word][4] += 1
-                if has_time_ref:
+                if has_time:
                     p_ref += 1
             elif hyp_word != ref_word:
                 subs[(ref_word, hyp_word)] += 1
                 words[ref_word][1] += 1
                 words[hyp_word][2] += 1
-                if has_time_ref:
+                if has_time:
                     p_hyp += 1
                     p_ref += 1
             else:
                 words[ref_word][0] += 1
                 num_corr += 1
-                if has_time_ref:
+                if has_time:
                     all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                     p_hyp += 1
                     p_ref += 1
-        if has_time_ref:
+        if has_time:
             assert p_hyp == len(hyp), (p_hyp, len(hyp))
             assert p_ref == len(ref), (p_ref, len(ref))
 
@@ -1327,10 +1327,9 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
 
 def parse_hyp_and_timestamp(
     res: DecodingResults,
-    decoding_method: str,
-    sp: spm.SentencePieceProcessor,
     subsampling_factor: int,
     frame_shift_ms: float = 10,
+    sp: Optional[spm.SentencePieceProcessor] = None,
     word_table: Optional[k2.SymbolTable] = None,
 ) -> Tuple[List[List[str]], List[List[float]]]:
     """Parse hypothesis and timestamp.
@@ -1338,51 +1337,29 @@ def parse_hyp_and_timestamp(
     Args:
       res:
         A DecodingResults object.
-      decoding_method:
-        Possible values are:
-          - greedy_search
-          - beam_search
-          - modified_beam_search
-          - fast_beam_search
-          - fast_beam_search_LG
-          - fast_beam_search_nbest
-          - fast_beam_search_nbest_oracle
-          - fast_beam_search_nbest_LG
-      sp:
-        The BPE model.
       subsampling_factor:
         The integer subsampling factor.
       frame_shift_ms:
         The float frame shift used for feature extraction.
+      sp:
+        The BPE model.
       word_table:
         The word symbol table.
 
     Returns:
        Return a list of hypothesis and timestamp.
     """
-    assert decoding_method in (
-        "greedy_search",
-        "beam_search",
-        "fast_beam_search",
-        "fast_beam_search_LG",
-        "fast_beam_search_nbest",
-        "fast_beam_search_nbest_LG",
-        "fast_beam_search_nbest_oracle",
-        "modified_beam_search",
-    )
-
     hyps = []
     timestamps = []
 
     N = len(res.hyps)
     assert len(res.timestamps) == N, (len(res.timestamps), N)
     use_word_table = False
-    if (
-        decoding_method == "fast_beam_search_nbest_LG"
-        and decoding_method == "fast_beam_search_LG"
-    ):
-        assert word_table is not None
+    if word_table is not None:
+        assert sp is None
         use_word_table = True
+    else:
+        assert sp is not None and word_table is None
 
     for i in range(N):
         time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)

From 4b5bc480e8a5ac253dcd22b08dfa59083dadd6fd Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Wed, 30 Nov 2022 17:26:05 +0800
Subject: [PATCH 052/120] Add low-order density ratio in RNNLM shallow fusion
 (#678)

* Support LODR in RNNLM shallow fusion

* fix style

* fix code style

* update workflow and CI

* update results

* propagate changes to stateless3

* add decoding results for stateless3+giga

* fix CI
---
 ...-lstm-transducer-stateless2-2022-09-03.yml |  67 ++++-
 ...-lstm-transducer-stateless2-2022-09-03.yml |  15 +-
 egs/librispeech/ASR/RESULTS.md                |  87 ++++++
 .../ASR/lstm_transducer_stateless2/decode.py  |  51 +++-
 .../beam_search.py                            | 264 ++++++++++++++++++
 .../pruned_transducer_stateless3/decode.py    | 181 +++++++++++-
 6 files changed, 646 insertions(+), 19 deletions(-)

diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 6ce92d022..ac5b15979 100755
--- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -16,6 +16,7 @@ log "Downloading pre-trained model from $repo_url"
 git lfs install
 git clone $repo_url
 repo=$(basename $repo_url)
+abs_repo=$(realpath $repo)
 
 log "Display test files"
 tree $repo/
@@ -178,21 +179,27 @@ echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
 if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
   lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
   log "Download pre-trained RNN-LM model from ${lm_repo_url}"
-  git clone $lm_repo_url
+  GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
   lm_repo=$(basename $lm_repo_url)
   pushd $lm_repo
   git lfs pull --include "exp/pretrained.pt"
-  cd exp
-  ln -s pretrained.pt epoch-88.pt
+  mv exp/pretrained.pt exp/epoch-88.pt
   popd
 
+  mkdir -p lstm_transducer_stateless2/exp
+  ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh lstm_transducer_stateless2/exp
+
+  log "Decoding test-clean and test-other"
+
   ./lstm_transducer_stateless2/decode.py \
     --use-averaged-model 0 \
-    --epoch 99 \
+    --epoch 999 \
     --avg 1 \
-    --exp-dir $repo/exp \
-    --lang-dir $repo/data/lang_bpe_500 \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --exp-dir lstm_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search_rnnlm_shallow_fusion \
     --beam 4 \
@@ -204,6 +211,52 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
     --rnn-lm-tie-weights 1
 fi
 
+if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then
+  bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram
+  log "Download bi-gram LM from ${bigram_repo_url}"
+  GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
+  bigramlm_repo=$(basename $bigram_repo_url)
+  pushd $bigramlm_repo
+  git lfs pull --include "2gram.fst.txt"
+  cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/.
+  popd
+
+  lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+  log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+  GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+  lm_repo=$(basename $lm_repo_url)
+  pushd $lm_repo
+  git lfs pull --include "exp/pretrained.pt"
+  mv exp/pretrained.pt exp/epoch-88.pt
+  popd
+
+  mkdir -p lstm_transducer_stateless2/exp
+  ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh lstm_transducer_stateless2/exp
+
+  log "Decoding test-clean and test-other"
+
+  ./lstm_transducer_stateless2/decode.py \
+    --use-averaged-model 0 \
+    --epoch 999 \
+    --avg 1 \
+    --exp-dir lstm_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir $lm_repo/exp \
+    --rnn-lm-epoch 88 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+fi
+
 if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
   mkdir -p lstm_transducer_stateless2/exp
   ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index a90841fb6..5f0acf9b8 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -18,7 +18,7 @@ on:
 
 jobs:
   run_librispeech_lstm_transducer_stateless2_2022_09_03:
-    if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -139,9 +139,20 @@ jobs:
           find modified_beam_search_rnnlm_shallow_fusion  -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
           find modified_beam_search_rnnlm_shallow_fusion  -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 
+      - name: Display decoding results for lstm_transducer_stateless2
+        if: github.event.label.name == 'LODR'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR
+          tree lstm_transducer_stateless2/exp
+          cd lstm_transducer_stateless2/exp
+          echo "===modified_beam_search_rnnlm_LODR==="
+          find modified_beam_search_rnnlm_LODR  -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find modified_beam_search_rnnlm_LODR  -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
       - name: Upload decoding results for lstm_transducer_stateless2
         uses: actions/upload-artifact@v2
-        if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion'
+        if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
         with:
           name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
           path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index efd60ba81..c2ea3d050 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -318,6 +318,7 @@ The WERs are:
 | greedy search (max sym per frame 1) | 2.78       | 7.36       | --iter 468000 --avg 16  |
 | modified_beam_search                | 2.73       | 7.15       | --iter 468000 --avg 16  |
 | modified_beam_search + RNNLM shallow fusion   | 2.42     |  6.46      | --iter 468000 --avg 16  |
+| modified_beam_search + RNNLM shallow fusion + LODR   | 2.28     |  5.94      | --iter 468000 --avg 16  |
 | fast_beam_search                    | 2.76       | 7.31       | --iter 468000 --avg 16  |
 | greedy search (max sym per frame 1) | 2.77       | 7.35       | --iter 472000 --avg 18  |
 | modified_beam_search                | 2.75       | 7.08       | --iter 472000 --avg 18  |
@@ -393,6 +394,32 @@ for iter in 472000; do
     done
 done
 
+You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in .
+It subtracts an estimate of the internal language model score during shallow fusion; the
+estimate is given by a low-order bi-gram model. The bi-gram can be generated by
+`generate-lm.sh`, or you may download it from .
+
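+Roughly speaking, each candidate non-blank token `y` is scored as
+`log p_ASR(y) + rnn_lm_scale * log p_RNNLM(y) + ngram_lm_scale * log p_bigram(y)`,
+where `ngram_lm_scale` is negative, so the bi-gram estimate of the internal LM
+is subtracted from the fused score.
+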
+The decoding command is as follows:
+
+```bash
+for iter in 472000; do
+    for avg in 8 10 12 14 16 18; do
+        ./lstm_transducer_stateless2/decode.py \
+                --iter $iter \
+                --avg $avg \
+                --exp-dir ./lstm_transducer_stateless2/exp \
+                --max-duration 600 \
+                --decoding-method modified_beam_search_rnnlm_LODR \
+                --beam 4 \
+                --rnn-lm-scale 0.4 \
+                --rnn-lm-exp-dir /path/to/RNNLM \
+                --rnn-lm-epoch 99 \
+                --rnn-lm-avg 1 \
+                --rnn-lm-num-layers 3 \
+                --rnn-lm-tie-weights 1 \
+                --tokens-ngram 2 \
+                --ngram-lm-scale -0.16
+    done
+done
+```
+
 Pretrained models, training logs, decoding logs, and decoding results
 are available at
 
@@ -1912,6 +1939,8 @@ subset so that the gigaspeech dataloader never exhausts.
 |-------------------------------------|------------|------------|---------------------------------------------|
 | greedy search (max sym per frame 1) | 2.03       | 4.70       | --iter 1224000 --avg 14  --max-duration 600 |
 | modified beam search                | 2.00       | 4.63       | --iter 1224000 --avg 14  --max-duration 600 |
+| modified beam search + rnnlm shallow fusion  | 1.94     |  4.2    | --iter 1224000 --avg 14  --max-duration 600 |
+| modified beam search + LODR         | 1.83       | 4.03       | --iter 1224000 --avg 14  --max-duration 600 |
 | fast beam search                    | 2.10       | 4.68       | --iter 1224000 --avg 14 --max-duration 600 |
 
 The training commands are:
@@ -1957,6 +1986,64 @@ for iter in 1224000; do
   done
 done
 ```
+You may also decode using shallow fusion with an external RNNLM. To do so, you need to
+download a well-trained RNNLM from this link 
+
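+In shallow fusion, the RNNLM log-probability of each candidate non-blank token is
+added to the transducer score, weighted by `--rnn-lm-scale`.
+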
+```bash
+rnn_lm_scale=0.3
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_shallow_fusion ; do
+      ./pruned_transducer_stateless3/decode.py \
+        --iter $iter \
+        --avg $avg \
+        --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+        --max-duration 600 \
+        --decoding-method $method \
+        --max-sym-per-frame 1 \
+        --beam 4 \
+        --max-contexts 32 \
+        --rnn-lm-scale $rnn_lm_scale \
+        --rnn-lm-exp-dir /path/to/RNNLM \
+        --rnn-lm-epoch 99 \
+        --rnn-lm-avg 1 \
+        --rnn-lm-num-layers 3 \
+        --rnn-lm-tie-weights 1
+    done
+  done
+done
+```
+
+If you want to try LODR decoding, use the following command. This assumes you have a bi-gram LM trained on the LibriSpeech text corpus. You can also download the bi-gram LM from here  and put it under the directory `data/lang_bpe_500`.
+
+```bash
+rnn_lm_scale=0.4
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_LODR ; do
+      ./pruned_transducer_stateless3/decode.py \
+        --iter $iter \
+        --avg $avg \
+        --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+        --max-duration 600 \
+        --decoding-method $method \
+        --max-sym-per-frame 1 \
+        --beam 4 \
+        --max-contexts 32 \
+        --rnn-lm-scale $rnn_lm_scale \
+        --rnn-lm-exp-dir /path/to/RNNLM \
+        --rnn-lm-epoch 99 \
+        --rnn-lm-avg 1 \
+        --rnn-lm-num-layers 3 \
+        --rnn-lm-tie-weights 1 \
+        --tokens-ngram 2 \
+        --ngram-lm-scale -0.14
+    done
+  done
+done
+```
 
 The pretrained models, training logs, decoding logs, and decoding results
 can be found at
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 69f695fef..fa5bf1825 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -107,8 +107,25 @@ Usage:
     --rnn-lm-avg 1 \
     --rnn-lm-num-layers 3 \
     --rnn-lm-tie-weights 1
-"""
 
+(9) modified beam search with RNNLM shallow fusion + LODR
+./lstm_transducer_stateless2/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./lstm_transducer_stateless2/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+"""
 
 import argparse
 import logging
@@ -132,6 +149,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
     modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
     modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
@@ -235,7 +253,8 @@ def get_parser():
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
           - modified_beam_search_ngram_rescoring
-          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion
+          - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -394,7 +413,8 @@ def get_parser():
         type=int,
         default=3,
         help="""Token Ngram used for rescoring.
-            Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+            Used only when the decoding method is
+            modified_beam_search_ngram_rescoring""",
     )
 
     parser.add_argument(
@@ -402,7 +422,8 @@ def get_parser():
         type=int,
         default=500,
         help="""ID of the backoff symbol.
-                Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+                Used only when the decoding method is
+                modified_beam_search_ngram_rescoring""",
     )
 
     add_model_arguments(parser)
@@ -572,6 +593,20 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
+        hyp_tokens = modified_beam_search_rnnlm_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
 
@@ -760,6 +795,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "modified_beam_search_rnnlm_LODR",
         "modified_beam_search_ngram_rescoring",
         "modified_beam_search_rnnlm_shallow_fusion",
     )
@@ -788,6 +824,9 @@ def main():
     if "rnnlm" in params.decoding_method:
         params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
 
+    if "LODR" in params.decoding_method:
+        params.suffix += "-LODR"
+
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
@@ -901,7 +940,7 @@ def main():
     model.eval()
 
     # only load N-gram LM when needed
-    if "ngram" in params.decoding_method:
+    if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
         lm_filename = f"{params.tokens_ngram}gram.fst.txt"
         logging.info(f"lm filename: {lm_filename}")
         ngram_lm = NgramLm(
@@ -910,6 +949,7 @@ def main():
             is_binary=False,
         )
         logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
     else:
         ngram_lm = None
         ngram_lm_scale = None
@@ -933,7 +973,6 @@ def main():
         )
         rnn_lm_model.to(device)
         rnn_lm_model.eval()
-
     else:
         rnn_lm_model = None
         rnn_lm_scale = 0.0
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 5e9428b60..59c8ed5b5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -2083,3 +2083,267 @@ def modified_beam_search_rnnlm_shallow_fusion(
             tokens=ans,
             timestamps=ans_timestamps,
         )
+
+
+def modified_beam_search_rnnlm_LODR(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    LODR_lm: NgramLm,
+    LODR_lm_scale: float,
+    rnnlm: RnnLmModel,
+    rnnlm_scale: float,
+    beam: int = 4,
+) -> List[List[int]]:
+    """This function implements LODR (https://arxiv.org/abs/2203.16776) with
+    `modified_beam_search`. It uses a bi-gram language model as the estimate
+    of the internal language model and subtracts its score during shallow fusion
+    with an external language model. This implementation uses an RNNLM as the
+    external language model.
+
+    Args:
+        model (Transducer):
+            The transducer model
+        encoder_out (torch.Tensor):
+            Encoder output in (N,T,C)
+        encoder_out_lens (torch.Tensor):
+            A 1-D tensor of shape (N,), containing the number of
+            valid frames in encoder_out before padding.
+        sp:
+            Sentence piece generator.
+        LODR_lm:
+            A low order n-gram LM
+        LODR_lm_scale:
+            The scale of the LODR_lm
+        rnnlm (RnnLmModel):
+            RNNLM, the external language model
+        rnnlm_scale (float):
+            scale of RNNLM in shallow fusion
+        beam (int, optional):
+            Beam size. Defaults to 4.
+
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding result
+      for the i-th utterance.
+
+    """
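+    # For each candidate non-blank token y, the hypothesis score is updated as
+    #   log_prob += log p_ASR(y) + rnnlm_scale * log p_RNNLM(y)
+    #               + LODR_lm_scale * log p_bigram(y)
+    # LODR_lm_scale is expected to be negative, so the bi-gram score (the
+    # estimate of the internal LM) is effectively subtracted during fusion.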
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+    assert rnnlm is not None
+    lm_scale = rnnlm_scale
+    vocab_size = rnnlm.vocab_size
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    sos_id = sp.piece_to_id("<sos/eos>")
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    # get initial lm score and lm state by scoring the "sos" token
+    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+    init_score, init_states = rnnlm.score_token(sos_token)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                state=init_states,  # state of the RNNLM
+                lm_score=init_score.reshape(-1),
+                state_cost=NgramLmStateCost(
+                    LODR_lm
+                ),  # state of the source domain ngram
+            )
+        )
+
+    rnnlm.clean_cache()
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]  # get batch
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            project_input=False,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+        """
+        for all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blank ones to the token_list.
+        The RNNLM will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
+        """
+        token_list = []
+        hs = []
+        cs = []
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+                    assert new_token != 0, new_token
+                    token_list.append([new_token])
+                    # store the LSTM states
+                    hs.append(hyp.state[0])
+                    cs.append(hyp.state[1])
+
+        # forward RNNLM to get new states and scores
+        if len(token_list) != 0:
+            tokens_to_score = (
+                torch.tensor(token_list)
+                .to(torch.int64)
+                .to(device)
+                .reshape(-1, 1)
+            )
+
+            hs = torch.cat(hs, dim=1).to(device)
+            cs = torch.cat(cs, dim=1).to(device)
+            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
+
+        count = 0  # index, used to locate score and lm states
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                ys = hyp.ys[:]
+
+                # current score of hyp
+                lm_score = hyp.lm_score
+                state = hyp.state
+
+                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+
+                    ys.append(new_token)
+                    state_cost = hyp.state_cost.forward_one_step(new_token)
+
+                    # calculate the score of the latest token
+                    current_ngram_score = (
+                        state_cost.lm_score - hyp.state_cost.lm_score
+                    )
+
+                    assert current_ngram_score <= 0.0, (
+                        state_cost.lm_score,
+                        hyp.state_cost.lm_score,
+                    )
+                    # score = score + RNNLM_score - LODR_score
+                    # LODR_LM_scale is a negative number here
+                    hyp_log_prob += (
+                        lm_score[new_token] * lm_scale
+                        + LODR_lm_scale * current_ngram_score
+                    )  # add the lm score
+
+                    lm_score = scores[count]
+                    state = (
+                        lm_states[0][:, count, :].unsqueeze(1),
+                        lm_states[1][:, count, :].unsqueeze(1),
+                    )
+                    count += 1
+                else:
+                    state_cost = hyp.state_cost
+
+                new_hyp = Hypothesis(
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    state_cost=state_cost,
+                )
+                B[i].add(new_hyp)
+
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
+    return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 03137501f..e00aab34a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                            Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -90,8 +91,40 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-"""
 
+(8) modified beam search (with RNNLM shallow fusion)
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir /path/to/RNNLM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+(9) modified beam search with RNNLM shallow fusion + LODR
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+"""
 
 import argparse
 import logging
@@ -116,10 +149,14 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
+    modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import NgramLm
 from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.rnn_lm.model import RnnLmModel
@@ -202,6 +239,9 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
+          - modified_beam_search_ngram_rescoring
+          - modified_beam_search_rnnlm_shallow_fusion
+          - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -263,6 +303,7 @@ def get_parser():
         default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
+
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -341,6 +382,15 @@ def get_parser():
          """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the scale of the RNN LM in shallow fusion.
+        """,
+    )
+
     parser.add_argument(
         "--rnn-lm-exp-dir",
         type=str,
@@ -397,6 +447,24 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Token Ngram used for rescoring.
+            Used only when the decoding method is
+            modified_beam_search_ngram_rescoring""",
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+                Used only when the decoding method is
+                modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -410,7 +478,10 @@ def decode_one_batch(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
-    rnn_lm_model: torch.nn.Module = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    rnn_lm_model: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -444,6 +515,14 @@ def decode_one_batch(
         fast_beam_search_nbest, fast_beam_search_nbest_oracle,
         or fast_beam_search_with_nbest_rescoring.
         It an FsaVec containing an acceptor.
+      rnn_lm_model:
+        A rnnlm which can be used for rescoring or shallow fusion
+      rnnlm_scale:
+        The scale of the rnnlm.
+      ngram_lm:
+        A ngram lm. Used in LODR decoding.
+      ngram_lm_scale:
+        The scale of the ngram language model.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -607,6 +686,43 @@ def decode_one_batch(
             nbest_scale=params.nbest_scale,
             temperature=params.temperature,
         )
+    elif params.decoding_method == "modified_beam_search_ngram_rescoring":
+        hyp_tokens = modified_beam_search_ngram_rescoring(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
+        hyp_tokens = modified_beam_search_rnnlm_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
 
@@ -693,7 +809,10 @@ def decode_dataset(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
-    rnn_lm_model: torch.nn.Module = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    rnn_lm_model: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -749,7 +868,10 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             batch=batch,
             G=G,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnn_lm_model=rnn_lm_model,
+            rnnlm_scale=rnnlm_scale,
         )
 
         for name, hyps in hyps_dict.items():
@@ -900,6 +1022,9 @@ def main():
         "modified_beam_search",
         "fast_beam_search_with_nbest_rescoring",
         "fast_beam_search_with_nbest_rnn_rescoring",
+        "modified_beam_search_rnnlm_LODR",
+        "modified_beam_search_ngram_rescoring",
+        "modified_beam_search_rnnlm_shallow_fusion",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
@@ -930,6 +1055,13 @@ def main():
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
         params.suffix += f"-temperature-{params.temperature}"
 
+    if "rnnlm" in params.decoding_method:
+        params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
+    if "LODR" in params.decoding_method:
+        params.suffix += "-LODR"
+    if "ngram" in params.decoding_method:
+        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
 
@@ -1048,6 +1180,44 @@ def main():
         word_table = None
         rnn_lm_model = None
 
+    # only load N-gram LM when needed
+    if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"lm filename: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
+
+    # only load rnnlm if used
+    if "rnnlm" in params.decoding_method:
+        rnn_lm_scale = params.rnn_lm_scale
+
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.vocab_size,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
+        assert params.rnn_lm_avg == 1
+
+        load_checkpoint(
+            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+            rnn_lm_model,
+        )
+        rnn_lm_model.to(device)
+        rnn_lm_model.eval()
+    else:
+        rnn_lm_model = None
+        rnn_lm_scale = 0.0
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -1074,7 +1244,10 @@ def main():
             word_table=word_table,
             decoding_graph=decoding_graph,
             G=G,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnn_lm_model=rnn_lm_model,
+            rnnlm_scale=rnn_lm_scale,
         )
 
         save_results(
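
The decoding branches added above combine the transducer's per-token log-probabilities with the RNN LM; the LODR branch additionally adds a low-order token-level n-gram score with a negative scale, so that an estimate of the transducer's internal LM is subtracted. A minimal sketch of that score combination on plain tensors is given below; `combine_scores` is a hypothetical helper written for illustration, not the functions in `beam_search.py`.

```python
from typing import Optional

import torch


def combine_scores(
    transducer_logp: torch.Tensor,  # (num_hyps, vocab_size) transducer log-probs
    rnnlm_logp: torch.Tensor,  # (num_hyps, vocab_size) RNN LM log-probs
    rnnlm_scale: float,
    lodr_logp: Optional[torch.Tensor] = None,  # low-order n-gram log-probs
    lodr_scale: float = 0.0,  # negative in practice, e.g. -0.1
) -> torch.Tensor:
    # Shallow fusion: add the scaled external LM log-probs.
    scores = transducer_logp + rnnlm_scale * rnnlm_logp
    # LODR: also add the low-order n-gram log-probs with a negative scale,
    # i.e. subtract an estimate of the transducer's internal LM.
    if lodr_logp is not None:
        scores = scores + lodr_scale * lodr_logp
    return scores
```

The beam search then keeps the top `--beam-size` hypotheses under these combined scores, which is why `--rnn-lm-scale` (and, for LODR, the n-gram scale) directly trades the transducer off against the external LM.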

From 556c63fbb741bcbc1669ec6848e06b08480d001f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Thu, 1 Dec 2022 08:58:18 +0800
Subject: [PATCH 053/120] Describe how to fix segfault in doc (#719)

---
 docs/source/installation/index.rst | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index c4474c3d9..5b9fb2664 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -393,6 +393,17 @@ Now let us run the training part:
   We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
   even if there are GPUs available.
 
+.. hint::
+
+   In case you get a ``Segmentation fault (core dump)`` error, please use:
+
+      .. code-block:: bash
+
+        export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+   See more at `` if you are
+   interested.
+
 The training log is given below:
 
 .. code-block::

From 2bca7032afb0d5b9eb60f7bcf3bc15ad1e8d8a83 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Thu, 1 Dec 2022 15:57:43 +0800
Subject: [PATCH 054/120] Update RNNLM training scripts (#720)

* Update RNNLM training scripts

* Fix a typo

* Fix CI
---
 .github/workflows/run-ptb-rnn-lm.yml         | 67 ++++++++++++++++++++
 egs/librispeech/ASR/local/train_bpe_model.py |  4 ++
 egs/ptb/LM/prepare.sh                        | 38 ++++++-----
 egs/ptb/LM/rnn_lm                            |  1 +
 egs/ptb/LM/train-rnn-lm.sh                   | 67 ++++++++++++++++++++
 icefall/rnn_lm/compute_perplexity.py         |  2 +-
 icefall/rnn_lm/dataset.py                    |  4 +-
 icefall/rnn_lm/train.py                      | 10 +--
 8 files changed, 170 insertions(+), 23 deletions(-)
 create mode 100644 .github/workflows/run-ptb-rnn-lm.yml
 create mode 120000 egs/ptb/LM/rnn_lm
 create mode 100755 egs/ptb/LM/train-rnn-lm.sh

diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml
new file mode 100644
index 000000000..8ebc2e79b
--- /dev/null
+++ b/.github/workflows/run-ptb-rnn-lm.yml
@@ -0,0 +1,67 @@
+name: run-ptb-rnn-lm-training
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_ptb_rnn_lm_training:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | grep -v kaldifst | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Prepare data
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/ptb/LM
+          ./prepare.sh
+
+      - name: Run training
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/ptb/LM
+          ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2
+
+      - name: Upload pretrained models
+        uses: actions/upload-artifact@v2
+        if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+        with:
+          name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb
+          path: egs/ptb/LM/my-rnnlm-exp/
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 42aba9572..7f6f47e16 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -89,6 +89,10 @@ def main():
             bos_id=-1,
             eos_id=-1,
         )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
 
     shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
 
diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh
index 91c3c667a..69fab999a 100755
--- a/egs/ptb/LM/prepare.sh
+++ b/egs/ptb/LM/prepare.sh
@@ -22,9 +22,9 @@ dl_dir=$PWD/download
 # if the array contains xxx, yyy
 vocab_sizes=(
   500
-  1000
-  2000
-  5000
+  # 1000
+  # 2000
+  # 5000
 )
 
 # All files generated by this script are saved in "data".
@@ -42,11 +42,14 @@ log "dl_dir: $dl_dir"
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download data"
+
+  # Caution: The downloaded data has already been normalized for LM training.
+
   if [ ! -f $dl_dir/.complete ]; then
-    url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
+    url=http://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data
+    wget --directory-prefix $dl_dir $url/ptb.train.txt
+    wget --directory-prefix $dl_dir $url/ptb.valid.txt
+    wget --directory-prefix $dl_dir $url/ptb.test.txt
     touch $dl_dir/.complete
   fi
 fi
@@ -54,11 +57,15 @@ fi
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train BPE model"
 
+  # Caution: You have to use the same bpe model for training your acoustic model
+  # Caution: You have to use the same bpe model for training your acoustic model
+  # Caution: You have to use the same bpe model for training your acoustic model
+
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
-    mkdir -p $out_dir
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
     ./local/train_bpe_model.py \
-      --out-dir $out_dir \
+      --lang-dir $lang_dir \
       --vocab-size $vocab_size \
       --transcript $dl_dir/ptb.train.txt
   done
@@ -69,20 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # Note: ptb.train.txt has already been normalized
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.train.txt \
       --lm-archive $out_dir/lm_data.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.valid.txt \
       --lm-archive $out_dir/lm_data-valid.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.test.txt \
       --lm-archive $out_dir/lm_data-test.pt
   done
@@ -98,7 +106,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # in a sentence.
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/sort_lm_training_data.py \
       --in-lm-data $out_dir/lm_data.pt \
diff --git a/egs/ptb/LM/rnn_lm b/egs/ptb/LM/rnn_lm
new file mode 120000
index 000000000..87f29771e
--- /dev/null
+++ b/egs/ptb/LM/rnn_lm
@@ -0,0 +1 @@
+../../../icefall/rnn_lm
\ No newline at end of file
diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh
new file mode 100755
index 000000000..29c609ee1
--- /dev/null
+++ b/egs/ptb/LM/train-rnn-lm.sh
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+
+# Please run ./prepare.sh first
+
+stage=-1
+stop_stage=100
+
+# Number of GPUs to use for training
+world_size=1
+
+# Number of epochs to train
+num_epochs=20
+
+# Use this epoch for computing ppl
+use_epoch=19
+
+# number of models to average for computing ppl
+use_avg=2
+
+exp_dir=./my-rnnlm-exp
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Training RNN LM"
+
+  ./rnn_lm/train.py \
+    --exp-dir $exp_dir \
+    --start-epoch 0 \
+    --num-epochs $num_epochs \
+    --world-size $world_size \
+    --use-fp16 0 \
+    --vocab-size 500 \
+    \
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \
+    --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \
+    \
+    --embedding-dim 800 \
+    --hidden-dim 200 \
+    --num-layers 2 \
+    --tie-weights false \
+    --batch-size 50
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Computing perplexity"
+
+  ./rnn_lm/compute_perplexity.py \
+    --exp-dir $exp_dir \
+    --epoch $use_epoch \
+    --avg $use_avg \
+    --vocab-size 500 \
+    \
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \
+    \
+    --embedding-dim 800 \
+    --hidden-dim 200 \
+    --num-layers 2 \
+    --tie-weights false \
+    --batch-size 50
+fi
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 550801a8f..f75a89590 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -20,7 +20,7 @@ Usage:
   ./rnn_lm/compute_perplexity.py \
     --epoch 4 \
     --avg 2 \
-    --lm-data ./data/bpe_500/sorted_lm_data-test.pt
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt
 
 """
 
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 4bf982503..53be53f64 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -1,4 +1,4 @@
-# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright (c)  2021  Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -194,7 +194,7 @@ def get_dataloader(
         batch_size=params.batch_size,
     )
     if is_distributed:
-        sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
+        sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
     else:
         sampler = None
 
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 3ba5bfbee..803da99d6 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -24,7 +24,7 @@ Usage:
         --use-fp16 0 \
         --embedding-dim 800 \
         --hidden-dim 200 \
-        --num-layers 2\
+        --num-layers 2 \
         --batch-size 400
 
 """
@@ -83,7 +83,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -110,14 +110,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )
 
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=50,
+        default=400,
     )
 
     parser.add_argument(
@@ -165,7 +165,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,

From 04c9fc9c9f9e481cbfae18bb34252b878ff51f6a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Fri, 2 Dec 2022 09:18:28 +0800
Subject: [PATCH 055/120] Fix for older versions of k2 (#725)

---
 icefall/graph_compiler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 84be81254..0dcd777ad 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -81,7 +81,7 @@ class CtcTrainingGraphCompiler(object):
 
         self.ctc_topo._is_repeat_token_ = (
             self.ctc_topo.labels != self.ctc_topo.aux_labels
-        )
+        ).int()
 
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
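
The `.int()` cast in this patch converts the boolean comparison result to `int32`, presumably because older k2 releases cannot attach bool-valued tensors as FSA attributes while newer ones can. The dtype change itself is just:

```python
import torch

labels = torch.tensor([1, 2, 2, 3])
aux_labels = torch.tensor([1, 2, 0, 3])

is_repeat = labels != aux_labels
print(is_repeat.dtype)        # torch.bool
print(is_repeat.int().dtype)  # torch.int32
```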

From 6533f359c998cee6fcb618f7b221cbfee05512e8 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Fri, 2 Dec 2022 10:53:06 +0800
Subject: [PATCH 056/120] Fix CI (#726)

* Fix CI

* Disable shuffle for yesno.

See https://github.com/k2-fsa/icefall/issues/197
---
 .github/workflows/build-doc.yml               |  4 ++
 .github/workflows/run-aishell-2022-06-20.yml  |  4 ++
 .../workflows/run-gigaspeech-2022-05-13.yml   |  4 ++
 .../workflows/run-librispeech-2022-03-12.yml  |  4 ++
 .../workflows/run-librispeech-2022-04-29.yml  |  4 ++
 .../workflows/run-librispeech-2022-05-13.yml  |  4 ++
 .../run-librispeech-2022-11-11-stateless7.yml |  4 ++
 .../run-librispeech-2022-11-14-stateless8.yml |  4 ++
 ...-librispeech-conformer-ctc3-2022-11-28.yml |  4 ++
 ...-lstm-transducer-stateless2-2022-09-03.yml |  4 ++
 ...runed-transducer-stateless3-2022-05-13.yml |  4 ++
 ...aming-transducer-stateless2-2022-06-26.yml |  4 ++
 ...peech-transducer-stateless2-2022-04-19.yml |  4 ++
 .../run-pretrained-conformer-ctc.yml          |  4 ++
 ...-transducer-stateless-librispeech-100h.yml |  4 ++
 ...r-stateless-librispeech-multi-datasets.yml |  4 ++
 ...ransducer-stateless-modified-2-aishell.yml |  4 ++
 ...-transducer-stateless-modified-aishell.yml |  4 ++
 .../run-pretrained-transducer-stateless.yml   |  4 ++
 .../workflows/run-pretrained-transducer.yml   |  4 ++
 .github/workflows/run-ptb-rnn-lm.yml          |  4 ++
 ...netspeech-pruned-transducer-stateless2.yml |  6 +-
 .github/workflows/run-yesno-recipe.yml        | 10 +++-
 .github/workflows/style_check.yml             |  4 ++
 .github/workflows/test.yml                    | 60 ++++++++-----------
 egs/librispeech/ASR/local/train_bpe_model.py  |  1 -
 .../beam_search.py                            | 13 +---
 .../test_scaling.py                           |  8 ---
 egs/yesno/ASR/tdnn/asr_datamodule.py          |  2 +-
 29 files changed, 128 insertions(+), 60 deletions(-)

diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml
index dd0969f51..d7fe2c964 100644
--- a/.github/workflows/build-doc.yml
+++ b/.github/workflows/build-doc.yml
@@ -26,6 +26,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: build_doc-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   build-doc:
     if: github.event.label.name == 'doc' || github.event_name == 'push'
diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml
index e46b01a08..1865a0da8 100644
--- a/.github/workflows/run-aishell-2022-06-20.yml
+++ b/.github/workflows/run-aishell-2022-06-20.yml
@@ -34,6 +34,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_aishell_2022_06_20-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_aishell_2022_06_20:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml
index c631927fa..e438c5dba 100644
--- a/.github/workflows/run-gigaspeech-2022-05-13.yml
+++ b/.github/workflows/run-gigaspeech-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_gigaspeech_2022_05_13-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_gigaspeech_2022_05_13:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
index 5df710006..3ba6850cd 100644
--- a/.github/workflows/run-librispeech-2022-03-12.yml
+++ b/.github/workflows/run-librispeech-2022-03-12.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_03_12-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_03_12:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
index 24c062442..595b410b8 100644
--- a/.github/workflows/run-librispeech-2022-04-29.yml
+++ b/.github/workflows/run-librispeech-2022-04-29.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_04_29-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_04_29:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml
index 29215ec25..eb0b06a2d 100644
--- a/.github/workflows/run-librispeech-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_05_13-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_05_13:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
index 3b98b500e..365e2761a 100644
--- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
+++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_11_11_zipformer-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_11_11_zipformer:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
index eaab35189..acb11a8f4 100644
--- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
+++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_11_14_zipformer_stateless8:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
index 21f396c32..d763fb1c5 100644
--- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
+++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_11_28_conformer_ctc3:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 5f0acf9b8..59f116fde 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -16,6 +16,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_lstm_transducer_stateless2_2022_09_03:
     if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
index 66a2c240b..2c2bcab0c 100644
--- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_pruned_transducer_stateless3_2022_05_13:
     if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
index 55428861c..ac7e58b20 100644
--- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
+++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_streaming_2022_06_26-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_streaming_2022_06_26:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
index f520405e1..575727e22 100644
--- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
+++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_04_19-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_04_19:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml
index 9bc6a481f..7dbfd2bd9 100644
--- a/.github/workflows/run-pretrained-conformer-ctc.yml
+++ b/.github/workflows/run-pretrained-conformer-ctc.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_conformer_ctc-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_conformer_ctc:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
index 7a0f30b0f..d6b3de8d4 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
@@ -32,6 +32,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
index 797f3fe50..749fb3fca 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
@@ -32,6 +32,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
index 29e665881..92bf6feb8 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_modified_2_aishell:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
index 6193f28e7..e51da8bd8 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_modified_aishell:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
index 32208076c..2103d0510 100644
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -32,6 +32,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_pre_trained_transducer_stateless-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml
index 965d0f655..902319b55 100644
--- a/.github/workflows/run-pretrained-transducer.yml
+++ b/.github/workflows/run-pretrained-transducer.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_transducer-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml
index 8ebc2e79b..47ed958f2 100644
--- a/.github/workflows/run-ptb-rnn-lm.yml
+++ b/.github/workflows/run-ptb-rnn-lm.yml
@@ -16,6 +16,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_ptb_rnn_lm_training-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_ptb_rnn_lm_training:
     if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
index d96a3bfe6..8a7be0b80 100644
--- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
+++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
@@ -23,8 +23,12 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
-  run_librispeech_pruned_transducer_stateless3_2022_05_13:
+  run_wenetspeech_pruned_transducer_stateless2:
     if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech'
     runs-on: ${{ matrix.os }}
     strategy:
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
index ce77c47df..ed343aee5 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -21,11 +21,15 @@ on:
     branches:
       - master
   pull_request:
-    types: [labeled]
+    branches:
+      - master
+
+concurrency:
+  group: run-yesno-recipe-${{ github.ref }}
+  cancel-in-progress: true
 
 jobs:
   run-yesno-recipe:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -61,7 +65,7 @@ jobs:
 
       - name: Install Python dependencies
         run: |
-          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          grep -v '^#' ./requirements-ci.txt  | grep -v kaldifst | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
           pip install --no-binary protobuf protobuf
 
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 45d261ccc..fc1dcbfd4 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -24,6 +24,10 @@ on:
     branches:
       - master
 
+concurrency:
+  group: style_check-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   style_check:
     runs-on: ${{ matrix.os }}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 04fc0265f..4dbe99827 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -21,26 +21,23 @@ on:
     branches:
       - master
   pull_request:
-    types: [labeled]
+    branches:
+      - master
+
+concurrency:
+  group: test-${{ github.ref }}
+  cancel-in-progress: true
 
 jobs:
   test:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        # os: [ubuntu-18.04, macos-10.15]
-        # disable macOS test for now.
-        os: [ubuntu-18.04]
-        python-version: [3.7, 3.8]
-        torch: ["1.8.0", "1.11.0"]
-        torchaudio: ["0.8.0", "0.11.0"]
-        k2-version: ["1.15.1.dev20220427"]
-        exclude:
-          - torch: "1.8.0"
-            torchaudio: "0.11.0"
-          - torch: "1.11.0"
-            torchaudio: "0.8.0"
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.23.2.dev20221201"]
 
       fail-fast: false
 
@@ -67,11 +64,7 @@ jobs:
           # numpy 1.20.x does not support python 3.6
           pip install numpy==1.19
           pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-          if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then
-            pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-          else
-            pip install torchaudio==${{ matrix.torchaudio }}
-          fi
+          pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
 
           pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
           pip install git+https://github.com/lhotse-speech/lhotse
@@ -81,7 +74,6 @@ jobs:
 
           pip install kaldifst
           pip install onnxruntime
-
           pip install -r requirements.txt
 
       - name: Install graphviz
@@ -124,16 +116,14 @@ jobs:
           cd ../transducer_stateless
           pytest -v -s
 
-          if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
-            cd ../transducer
-            pytest -v -s
+          cd ../transducer
+          pytest -v -s
 
-            cd ../transducer_stateless2
-            pytest -v -s
+          cd ../transducer_stateless2
+          pytest -v -s
 
-            cd ../transducer_lstm
-            pytest -v -s
-          fi
+          cd ../transducer_lstm
+          pytest -v -s
 
       - name: Run tests
         if: startsWith(matrix.os, 'macos')
@@ -164,13 +154,11 @@ jobs:
           cd ../transducer_stateless
           pytest -v -s
 
-          if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
-            cd ../transducer
-            pytest -v -s
+          cd ../transducer
+          pytest -v -s
 
-            cd ../transducer_stateless2
-            pytest -v -s
+          cd ../transducer_stateless2
+          pytest -v -s
 
-            cd ../transducer_lstm
-            pytest -v -s
-          fi
+          cd ../transducer_lstm
+          pytest -v -s
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 7f6f47e16..43142aee4 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -93,7 +93,6 @@ def main():
         print(f"{model_file} exists - skipping")
         return
 
-
     shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 59c8ed5b5..b324cc9b7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -2230,9 +2230,7 @@ def modified_beam_search_rnnlm_LODR(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
         """
         for all hyps with a non-blank new token, score this token.
         It is a little confusing here because this for-loop
@@ -2267,10 +2265,7 @@ def modified_beam_search_rnnlm_LODR(
         # forward RNNLM to get new states and scores
         if len(token_list) != 0:
             tokens_to_score = (
-                torch.tensor(token_list)
-                .to(torch.int64)
-                .to(device)
-                .reshape(-1, 1)
+                torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
             )
 
             hs = torch.cat(hs, dim=1).to(device)
@@ -2304,9 +2299,7 @@ def modified_beam_search_rnnlm_LODR(
                     state_cost = hyp.state_cost.forward_one_step(new_token)
 
                     # calculate the score of the latest token
-                    current_ngram_score = (
-                        state_cost.lm_score - hyp.state_cost.lm_score
-                    )
+                    current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score
 
                     assert current_ngram_score <= 0.0, (
                         state_cost.lm_score,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
index e9dfe6d5e..42de2410a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
@@ -52,17 +52,9 @@ def test_scaled_conv2d():
         torch.jit.script(conv2d)
 
 
-def test_activation_balancer():
-    act = ActivationBalancer(
-        channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
-    )
-    torch.jit.script(act)
-
-
 def main():
     test_scaled_conv1d()
     test_scaled_conv2d()
-    test_activation_balancer()
 
 
 if __name__ == "__main__":
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 85e5f1358..3c1682fa1 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -121,7 +121,7 @@ class YesNoAsrDataModule(DataModule):
         group.add_argument(
             "--shuffle",
             type=str2bool,
-            default=True,
+            default=False,
             help="When enabled (=default), the examples will be "
             "shuffled for each epoch.",
         )

From 6f719816673761ceda0bfe6bece5a44b151ead46 Mon Sep 17 00:00:00 2001
From: Amir Hussein <36240131+AmirHussein96@users.noreply.github.com>
Date: Thu, 1 Dec 2022 21:58:34 -0500
Subject: [PATCH 057/120] MGB2 (#396)

* mgb2

* mgb2

* adding pruned transducer stateless to mgb2

* update display_manifest_statistics.py

* .

* stateless transducer MGB-2

* Update README.md

* Update RESULTS.md

* Update prepare_lang_bpe.py

* Update asr_datamodule.py

* .nfs removed

* Adding symlink

* .

* resolving conflicts

* Update .gitignore

* black formatting

* Update compile_hlg.py

* Update compute_fbank_musan.py

* Update convert_transcript_words_to_tokens.py

* Update download_lm.py

* Update generate_unique_lexicon.py

* adding simlinks

* fixing symbolic links
---
 .gitignore                                    |   20 +
 egs/mgb2/ASR/README.md                        |   43 +
 egs/mgb2/ASR/RESULTS.md                       |  236 ++++
 egs/mgb2/ASR/conformer_ctc/__init__.py        |    0
 egs/mgb2/ASR/conformer_ctc/ali.py             |  395 ++++++
 egs/mgb2/ASR/conformer_ctc/asr_datamodule.py  |  372 ++++++
 egs/mgb2/ASR/conformer_ctc/compile_hlg.py     |    1 +
 .../ASR/conformer_ctc/compute_fbank_musan.py  |    1 +
 egs/mgb2/ASR/conformer_ctc/conformer.py       |    1 +
 .../convert_transcript_words_to_tokens.py     |    1 +
 egs/mgb2/ASR/conformer_ctc/decode.py          |  695 ++++++++++
 egs/mgb2/ASR/conformer_ctc/download_lm.py     |    1 +
 egs/mgb2/ASR/conformer_ctc/export.py          |    1 +
 .../conformer_ctc/generate_unique_lexicon.py  |    1 +
 egs/mgb2/ASR/conformer_ctc/label_smoothing.py |    1 +
 egs/mgb2/ASR/conformer_ctc/pretrained.py      |  430 ++++++
 egs/mgb2/ASR/conformer_ctc/subsampling.py     |    1 +
 .../ASR/conformer_ctc/test_label_smoothing.py |    1 +
 .../ASR/conformer_ctc/test_subsampling.py     |    1 +
 .../ASR/conformer_ctc/test_transformer.py     |    1 +
 egs/mgb2/ASR/conformer_ctc/train.py           |  766 +++++++++++
 egs/mgb2/ASR/conformer_ctc/transformer.py     |    1 +
 egs/mgb2/ASR/local/__init__.py                |    0
 egs/mgb2/ASR/local/compile_hlg.py             |    1 +
 egs/mgb2/ASR/local/compute_fbank_mgb2.py      |  101 ++
 egs/mgb2/ASR/local/compute_fbank_musan.py     |  108 ++
 .../convert_transcript_words_to_tokens.py     |  103 ++
 .../ASR/local/display_manifest_statistics.py  |   97 ++
 egs/mgb2/ASR/local/generate_unique_lexicon.py |    1 +
 egs/mgb2/ASR/local/prep_mgb2_lexicon.sh       |   30 +
 egs/mgb2/ASR/local/prepare_lang.py            |    1 +
 egs/mgb2/ASR/local/prepare_lang_bpe.py        |    1 +
 egs/mgb2/ASR/local/prepare_mgb2_lexicon.py    |   37 +
 egs/mgb2/ASR/local/test_prepare_lang.py       |    1 +
 egs/mgb2/ASR/prepare.sh                       |  234 ++++
 .../pruned_transducer_stateless5/__init__.py  |    0
 .../asr_datamodule.py                         |    1 +
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless5/conformer.py |    1 +
 .../pruned_transducer_stateless5/decode.py    |  625 +++++++++
 .../pruned_transducer_stateless5/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless5/export.py    |  272 ++++
 .../pruned_transducer_stateless5/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless5/model.py |    1 +
 .../ASR/pruned_transducer_stateless5/optim.py |    1 +
 .../pretrained.py                             |  344 +++++
 .../pruned_transducer_stateless5/scaling.py   |    1 +
 .../test_model.py                             |    1 +
 .../ASR/pruned_transducer_stateless5/train.py | 1176 +++++++++++++++++
 egs/mgb2/ASR/shared                           |    1 +
 icefall/diagnostics.py                        |    2 +-
 52 files changed, 6114 insertions(+), 1 deletion(-)
 create mode 100644 egs/mgb2/ASR/README.md
 create mode 100644 egs/mgb2/ASR/RESULTS.md
 create mode 100644 egs/mgb2/ASR/conformer_ctc/__init__.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/ali.py
 create mode 100644 egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/compile_hlg.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/conformer.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/decode.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/download_lm.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/export.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/label_smoothing.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/pretrained.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/subsampling.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/test_subsampling.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/test_transformer.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/train.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/transformer.py
 create mode 100644 egs/mgb2/ASR/local/__init__.py
 create mode 120000 egs/mgb2/ASR/local/compile_hlg.py
 create mode 100755 egs/mgb2/ASR/local/compute_fbank_mgb2.py
 create mode 100755 egs/mgb2/ASR/local/compute_fbank_musan.py
 create mode 100755 egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
 create mode 100755 egs/mgb2/ASR/local/display_manifest_statistics.py
 create mode 120000 egs/mgb2/ASR/local/generate_unique_lexicon.py
 create mode 100755 egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
 create mode 120000 egs/mgb2/ASR/local/prepare_lang.py
 create mode 120000 egs/mgb2/ASR/local/prepare_lang_bpe.py
 create mode 100755 egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
 create mode 120000 egs/mgb2/ASR/local/test_prepare_lang.py
 create mode 100755 egs/mgb2/ASR/prepare.sh
 create mode 100644 egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/export.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/model.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/train.py
 create mode 120000 egs/mgb2/ASR/shared

diff --git a/.gitignore b/.gitignore
index 406deff6a..583410f45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,5 +11,25 @@ log
 *.bak
 *-bak
 *bak.py
+
+# Ignore Mac system files
+.DS_store
+
+# Ignore node_modules folder
+node_modules
+
+# ignore .nfs
+
+.nfs*
+
+# Ignore all text files
+*.txt
+
+# Ignore files related to API keys
+.env
+
+# Ignore SASS config files
+.sass-cache
+
 *.param
 *.bin
diff --git a/egs/mgb2/ASR/README.md b/egs/mgb2/ASR/README.md
new file mode 100644
index 000000000..2bc4b000b
--- /dev/null
+++ b/egs/mgb2/ASR/README.md
@@ -0,0 +1,43 @@
+# MGB2
+
+The Multi-Dialect Broadcast News Arabic Speech Recognition (MGB-2):
+The second edition of the Multi-Genre Broadcast (MGB-2) Challenge is
+an evaluation of speech recognition and lightly supervised alignment
+using TV recordings in Arabic. The speech data is broad and multi-genre,
+spanning the whole range of TV output, and represents a challenging task for
+speech technology. In 2016, the challenge featured two new Arabic tracks based
+on TV data from Aljazeera. It was an official challenge at the 2016 IEEE
+Workshop on Spoken Language Technology. The 1,200 hours of MGB-2 data, drawn from
+Aljazeera TV programs, have been manually captioned with no timing information.
+The QCRI Arabic ASR system has been used to recognize all programs. The ASR output
+was used to align the manual captioning and produce speech segments for
+training speech recognition. More than 20 hours from 2015 programs have been
+transcribed verbatim and manually segmented. This data is split into a
+development set of 10 hours, and a similar evaluation set of 10 hours.
+Both the development and evaluation data have been released in the 2016 MGB
+challenge.
+
+Official reference:
+
+Ali, Ahmed, et al. "The MGB-2 challenge: Arabic multi-dialect broadcast media recognition." 
+2016 IEEE Spoken Language Technology Workshop (SLT). IEEE, 2016.
+
+IEEE link: https://ieeexplore.ieee.org/abstract/document/7846277
+
+## Stateless Pruned Transducer Performance Record (after 30 epochs)
+
+|                                    |     dev    |    test    | comment                                  |
+|------------------------------------|------------|------------|------------------------------------------|
+|          greedy search             | 15.52      | 15.28      | --epoch 18, --avg 5, --max-duration 200  |
+| modified beam search               | 13.88      | 13.7       | --epoch 18, --avg 5, --max-duration 200  |
+| fast beam search                   | 14.62      | 14.36      | --epoch 18, --avg 5, --max-duration 200  |
+
+## Conformer-CTC Performance Record (after 40 epochs)
+
+| Decoding method           | dev WER    | test WER |
+|---------------------------|------------|---------|
+| attention-decoder         | 15.62      |  15.01  |
+| whole-lattice-rescoring   | 15.89      |  15.08  |
+
+
+See [RESULTS](/egs/mgb2/ASR/RESULTS.md) for details.
diff --git a/egs/mgb2/ASR/RESULTS.md b/egs/mgb2/ASR/RESULTS.md
new file mode 100644
index 000000000..2a7ea7664
--- /dev/null
+++ b/egs/mgb2/ASR/RESULTS.md
@@ -0,0 +1,236 @@
+# Results
+
+
+### MGB2 all data BPE training results (Stateless Pruned Transducer)
+
+#### 2022-09-07
+
+The WERs are
+
+|                                    |     dev    |    test    | comment                                  |
+|------------------------------------|------------|------------|------------------------------------------|
+|          greedy search             | 15.52      | 15.28      | --epoch 18, --avg 5, --max-duration 200 |
+| modified beam search               | 13.88      | 13.7       | --epoch 18, --avg 5, --max-duration 200 |
+| fast beam search                   | 14.62      | 14.36      | --epoch 18, --avg 5, --max-duration 200|
+
+The training command for reproducing the results is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 300 \
+  --num-buckets 50
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/YyNv45pfQ0GqWzZ898WOlw/#scalars
+
+The decoding command is:
+```
+epoch=18
+avg=5
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch $epoch \
+    --beam-size 10 \
+    --avg $avg \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 12 \
+    --dim-feedforward 2048 \
+    --nhead 8 \
+    --encoder-dim 512 \
+    --decoder-dim 512 \
+    --joiner-dim 512 \
+    --use-averaged-model True
+done
+```
+
+### MGB2 all data BPE training results (Conformer-CTC) (after 40 epochs)
+
+#### 2022-06-04
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06
+
+The best WER, as of 2022-06-04, for the MGB2 test dataset is below
+
+Using whole lattice HLG decoding + n-gram LM rescoring 
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER | 15.62      |  15.01     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.1            | -            |
+
+
+Using n-best attention decoder rescoring (nbest-scale=0.5)
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER |    15.89   |  15.08     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.01           | 0.5             |
+
+
+To reproduce the above result, use the following commands for training:
+
+Note: the model was trained on a V100 32 GB GPU.
+
+```
+cd egs/mgb2/ASR
+. ./path.sh
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1"
+./conformer_ctc/train.py \
+  --lang-dir data/lang_bpe_5000 \
+  --att-rate 0.8 \
+  --lr-factor 10 \
+  --concatenate-cuts 0 \
+  --world-size 2 \
+  --bucketing-sampler 1 \
+  --max-duration 100 \
+  --start-epoch 0 \
+  --num-epochs 40
+  
+```
+
+and the following command for nbest decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+and the following command for whole-lattice decoding
+
+```
+./conformer_ctc/decode.py \
+  --epoch 40 \
+  --avg 5 \
+  --exp-dir conformer_ctc/exp_5000_att0.8 \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --method  whole-lattice-rescoring
+```
+
+
+The tensorboard log for training is available at
+https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars
+
+
+### MGB2 100h BPE training results (Conformer-CTC) (after 33 epochs)
+
+#### 2022-06-04
+
+The best WERs, as of 2022-06-04, for the MGB2 dev and test sets are given below.
+
+Using whole lattice HLG decoding + n-gram LM rescoring 
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER | 25.32      |  23.53     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.1            | -            |
+
+
+Using n-best (nbest-scale=0.5) HLG decoding + n-gram LM rescoring + attention decoder rescoring:
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER |    27.87   |  26.12     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.01           | 0.3             |
+
+
+To reproduce the above result, use the following commands for training:
+
+Note: the model was trained on a V100 32GB GPU.
+
+```
+cd egs/mgb2/ASR
+. ./path.sh
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1"
+./conformer_ctc/train.py \
+  --lang-dir data/lang_bpe_5000 \
+  --att-rate 0.8 \
+  --lr-factor 10 \
+  --max-duration 100 \
+  --concatenate-cuts 0 \
+  --world-size 2 \
+  --bucketing-sampler 1 \
+  --start-epoch 0 \
+  --num-epochs 40
+```
+
+and the following command for n-best (attention-decoder) decoding:
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+and the following command for whole-lattice rescoring:
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method  whole-lattice-rescoring
+```
+
+The tensorboard log for training is available at
+
diff --git a/egs/mgb2/ASR/conformer_ctc/__init__.py b/egs/mgb2/ASR/conformer_ctc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/conformer_ctc/ali.py b/egs/mgb2/ASR/conformer_ctc/ali.py
new file mode 100755
index 000000000..aea962dcd
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/ali.py
@@ -0,0 +1,395 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+    ./conformer_ctc/ali.py \
+            --exp-dir ./conformer_ctc/exp \
+            --lang-dir ./data/lang_bpe_500 \
+            --epoch 20 \
+            --avg 10 \
+            --max-duration 300 \
+            --dataset train-clean-100 \
+            --out-dir data/ali
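+
+After the script finishes, the stored alignments can be loaded back from the
+output cut manifest. A minimal sketch (assuming lhotse's custom-field API and
+the paths from the example above):
+
+    from lhotse import CutSet
+
+    cuts = CutSet.from_file("data/ali/cuts_train-clean-100.json.gz")
+    cut = next(iter(cuts))
+    labels = cut.load_custom("labels_alignment")          # with repeats
+    aux_labels = cut.load_custom("aux_labels_alignment")  # without repeats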
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse import CutSet
+from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import one_best_decoding
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    get_alignments,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=20,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--out-dir",
+        type=str,
+        required=True,
+        help="""Output directory.
+        It contains 3 generated files:
+
+        - labels_xxx.h5
+        - aux_labels_xxx.h5
+        - cuts_xxx.json.gz
+
+        where xxx is the value of `--dataset`. For instance, if
+        `--dataset` is `train-clean-100`, it will contain 3 files:
+
+        - `labels_train-clean-100.h5`
+        - `aux_labels_train-clean-100.h5`
+        - `cuts_train-clean-100.json.gz`
+
+        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
+        alignment. The difference is that labels_xxx.h5 contains repeats.
+        """,
+    )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+        help="""The name of the dataset to compute alignments for.
+        Possible values are:
+            - test-clean.
+            - test-other
+            - train-clean-100
+            - train-clean-360
+            - train-other-500
+            - dev-clean
+            - dev-other
+        """,
+    )
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "lm_dir": Path("data/lm"),
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            # Set it to 0 since attention decoder
+            # is not used for computing alignments
+            "num_decoder_layers": 0,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_alignments(
+    model: torch.nn.Module,
+    dl: torch.utils.data.DataLoader,
+    labels_writer: FeaturesWriter,
+    aux_labels_writer: FeaturesWriter,
+    params: AttributeDict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+) -> CutSet:
+    """Compute the framewise alignments of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      dl:
+        Dataloader containing the dataset.
+      params:
+        Parameters for computing alignments.
+      graph_compiler:
+        It converts token IDs to decoding graphs.
+    Returns:
+      Return a CutSet. Each cut has two custom fields: labels_alignment
+      and aux_labels_alignment, containing framewise alignments information.
+      Both are of type `lhotse.array.TemporalArray`. The difference between
+      the two alignments is that `labels_alignment` contain repeats.
+    """
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+    num_cuts = 0
+
+    device = graph_compiler.device
+    cuts = []
+    for batch_idx, batch in enumerate(dl):
+        feature = batch["inputs"]
+
+        # at entry, feature is [N, T, C]
+        assert feature.ndim == 3
+        feature = feature.to(device)
+
+        supervisions = batch["supervisions"]
+        cut_list = supervisions["cut"]
+
+        for cut in cut_list:
+            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
+
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+        # we need also to sort cut_ids as encode_supervisions()
+        # reorders "texts".
+        # In general, new2old is an identity map since lhotse sorts the returned
+        # cuts by duration in descending order
+        new2old = supervision_segments[:, 0].tolist()
+
+        cut_list = [cut_list[i] for i in new2old]
+
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        lattice = k2.intersect_dense(
+            decoding_graph,
+            dense_fsa_vec,
+            params.output_beam,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice,
+            use_double_scores=params.use_double_scores,
+        )
+
+        labels_ali = get_alignments(best_path, kind="labels")
+        aux_labels_ali = get_alignments(best_path, kind="aux_labels")
+        assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
+        for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali):
+            cut.labels_alignment = labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+            cut.aux_labels_alignment = aux_labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(aux_labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+
+        cuts += cut_list
+
+        num_cuts += len(cut_list)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+    return CutSet.from_cuts(cuts)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    args.enable_spec_aug = False
+    args.enable_musan = False
+    args.return_cuts = True
+    args.concatenate_cuts = False
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-ali")
+
+    logging.info(f"Computing alignments for {params.dataset} - started")
+    logging.info(params)
+
+    out_dir = Path(params.out_dir)
+    out_dir.mkdir(exist_ok=True)
+
+    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
+    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
+    out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
+
+    for f in (
+        out_labels_ali_filename,
+        out_aux_labels_ali_filename,
+        out_manifest_filename,
+    ):
+        if f.exists():
+            logging.info(f"{f} exists - skipping")
+            return
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+    model.to(device)
+
+    if params.avg == 1:
+        load_checkpoint(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
+        )
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.eval()
+
+    librispeech = LibriSpeechAsrDataModule(args)
+    if params.dataset == "test-clean":
+        test_clean_cuts = librispeech.test_clean_cuts()
+        dl = librispeech.test_dataloaders(test_clean_cuts)
+    elif params.dataset == "test-other":
+        test_other_cuts = librispeech.test_other_cuts()
+        dl = librispeech.test_dataloaders(test_other_cuts)
+    elif params.dataset == "train-clean-100":
+        train_clean_100_cuts = librispeech.train_clean_100_cuts()
+        dl = librispeech.train_dataloaders(train_clean_100_cuts)
+    elif params.dataset == "train-clean-360":
+        train_clean_360_cuts = librispeech.train_clean_360_cuts()
+        dl = librispeech.train_dataloaders(train_clean_360_cuts)
+    elif params.dataset == "train-other-500":
+        train_other_500_cuts = librispeech.train_other_500_cuts()
+        dl = librispeech.train_dataloaders(train_other_500_cuts)
+    elif params.dataset == "dev-clean":
+        dev_clean_cuts = librispeech.dev_clean_cuts()
+        dl = librispeech.valid_dataloaders(dev_clean_cuts)
+    else:
+        assert params.dataset == "dev-other", f"{params.dataset}"
+        dev_other_cuts = librispeech.dev_other_cuts()
+        dl = librispeech.valid_dataloaders(dev_other_cuts)
+
+    logging.info(f"Processing {params.dataset}")
+    with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
+        with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
+            cut_set = compute_alignments(
+                model=model,
+                dl=dl,
+                labels_writer=labels_writer,
+                aux_labels_writer=aux_labels_writer,
+                params=params,
+                graph_compiler=graph_compiler,
+            )
+
+    cut_set.to_file(out_manifest_filename)
+
+    logging.info(
+        f"For dataset {params.dataset}, its alignments with repeats are "
+        f"saved to {out_labels_ali_filename}, the alignments without repeats "
+        f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+    )
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
new file mode 100644
index 000000000..8242e986d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1,372 @@
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class MGB2AsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders.
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=1,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse's versions, the default of num_frame_masks is
+            # different.
+            num_frame_masks = 10
+            num_frame_masks_parameter = inspect.signature(
+                SpecAugment.__init__
+            ).parameters["num_frame_masks"]
+            if num_frame_masks_parameter.default == 1:
+                num_frame_masks = 2
+            logging.info(f"Num frame mask: {num_frame_masks}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=num_frame_masks,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        train = K2SpeechRecognitionDataset(
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+
+    @lru_cache()
+    def dev_cuts(self) -> CutSet:
+        logging.info("About to get dev cuts")
+
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        logging.info("About to get test cuts")
+
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
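+
+
+# A minimal usage sketch (not part of the recipe; it assumes the manifests
+# listed above already exist under --manifest-dir):
+#
+#   parser = argparse.ArgumentParser()
+#   MGB2AsrDataModule.add_arguments(parser)
+#   args = parser.parse_args([])
+#   data_module = MGB2AsrDataModule(args)
+#   train_dl = data_module.train_dataloaders(data_module.train_cuts())
+#   valid_dl = data_module.valid_dataloaders(data_module.dev_cuts())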
diff --git a/egs/mgb2/ASR/conformer_ctc/compile_hlg.py b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/conformer.py b/egs/mgb2/ASR/conformer_ctc/conformer.py
new file mode 120000
index 000000000..d1f4209d7
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/conformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py
new file mode 100755
index 000000000..f771d7f1e
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/decode.py
@@ -0,0 +1,695 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=50,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=5,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound that any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=20,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.method is "1best", it uses 1best decoding without LM rescoring.
+        - params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
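+
+      For example, with whole-lattice rescoring the returned dict may look
+      like the following (illustrative values only):
+
+          {"lm_scale_0.7": [["the", "first", "utterance"], ["the", "second"]]}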
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is (N, T, C)
+
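+    # Each row of supervision_segments is
+    # (sequence_idx, start_frame, num_frames), with start_frame and
+    # num_frames given in frames after subsampling.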
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of token IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.method in ["1best", "nbest"]:
+        if params.method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "attention-decoder":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
+        )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            nbest_scale=params.nbest_scale,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    test_cuts = MGB2.test_cuts()
+    dev_cuts = MGB2.dev_cuts()
+
+    test_dl = MGB2.test_dataloaders(test_cuts)
+    dev_dl = MGB2.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_all_dl = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_all_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/download_lm.py b/egs/mgb2/ASR/conformer_ctc/download_lm.py
new file mode 120000
index 000000000..c9668bd2d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/download_lm.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/download_lm.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/export.py b/egs/mgb2/ASR/conformer_ctc/export.py
new file mode 120000
index 000000000..60e314d9d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/export.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py
new file mode 100755
index 000000000..d30ca98d8
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import AttributeDict, get_texts
+
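+# A minimal usage sketch (the paths below are placeholders for your own files):
+#
+#   ./conformer_ctc/pretrained.py \
+#     --checkpoint conformer_ctc/exp/pretrained.pt \
+#     --bpe-model data/lang_bpe_5000/bpe.model \
+#     --num-classes 5000 \
+#     --method ctc-decoding \
+#     foo.wav bar.wav
+#
+# For the HLG-based methods (1best, whole-lattice-rescoring, attention-decoder)
+# pass --words-file and --HLG (plus --G for the rescoring methods) instead of
+# --bpe-model.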
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) attention-decoder - Extract n paths from the rescored
+            lattice and use the transformer attention decoder for
+            rescoring.
+            We call it HLG decoding + n-gram LM rescoring + attention
+            decoder rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or attention-decoder.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or attention-decoder.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--attention-decoder-scale",
+        type=float,
+        default=1.2,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for attention decoder scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, at the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--sos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the SOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--eos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the EOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "sample_rate": 16000,
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
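+            # (the beam and active-state limits below are passed to
+            # get_lattice, i.e. to k2.intersect_dense_pruned)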
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    if args.method != "attention-decoder":
+        # to save memory as the attention decoder
+        # will not be used
+        params.num_decoder_layers = 0
+
+    params.update(vars(args))
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=params.num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
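+    # Padded frames are filled with log(1e-10), a very small value in the
+    # log-mel domain, so padding behaves like near-silence rather than energy.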
+
+    # Note: We don't use key padding mask for attention during decoding
+    with torch.no_grad():
+        nnet_output, memory, memory_key_padding_mask = model(features)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
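+    # Each row of supervision_segments is (sequence_index, start_frame,
+    # num_frames); every utterance here starts at frame 0 and spans all
+    # frames of nnet_output.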
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = G.to(device)
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "attention-decoder":
+            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+            rescored_lattice = rescore_with_whole_lattice(
+                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            )
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                nbest_scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/subsampling.py b/egs/mgb2/ASR/conformer_ctc/subsampling.py
new file mode 120000
index 000000000..16354dc73
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/subsampling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
new file mode 120000
index 000000000..04b959ecf
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_subsampling.py b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py
new file mode 120000
index 000000000..98c3be3e6
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_subsampling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_transformer.py b/egs/mgb2/ASR/conformer_ctc/test_transformer.py
new file mode 120000
index 000000000..8b0990ec6
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_transformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/train.py b/egs/mgb2/ASR/conformer_ctc/train.py
new file mode 100755
index 000000000..08ffee210
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/train.py
@@ -0,0 +1,766 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
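+# A minimal usage sketch (the flag values are illustrative and should be
+# adapted to your setup; dataloader options such as --max-duration come from
+# asr_datamodule.py):
+#
+#   ./conformer_ctc/train.py \
+#     --world-size 4 \
+#     --num-epochs 50 \
+#     --exp-dir conformer_ctc/exp \
+#     --lang-dir data/lang_bpe_5000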
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
+    parser.add_argument(
+        "--lr-factor",
+        type=float,
+        default=5.0,
+        help="The lr_factor for Noam optimizer",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - use_feat_batchnorm: Normalization for the input features, can be a
+                              boolean indicating whether to do batch
+                              normalization, or a float which means just scaling
+                              the input features with this float value.
+                              If given a float value, we will remove batchnorm
+                              layer in `ConvolutionModule` as well.
+
+        - attention_dim: Hidden dim of the multi-head attention model.
+
+        - nhead: Number of heads in the multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - beam_size: It is used in k2.ctc_loss
+
+        - reduction: It is used in k2.ctc_loss
+
+        - use_double_scores: It is used in k2.ctc_loss
+
+        - weight_decay:  The weight_decay for the optimizer.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "use_feat_batchnorm": True,
+            "attention_dim": 512,
+            "nhead": 8,
+            "num_decoder_layers": 6,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            # parameters for Noam
+            "weight_decay": 1e-6,
+            "warm_step": 80000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> Optional[dict]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return the saved checkpoint (a dict) if one was loaded, else None.
+    """
+    if params.start_epoch <= 0:
+        return
+
+    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    batch: dict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = graph_compiler.device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is (N, T, C)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction="none",
+        use_double_scores=params.use_double_scores,
+    )
+    # filter inf from ctc_loss
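+    # (an inf loss typically means an utterance could not be aligned, e.g. its
+    #  transcript is longer than the number of frames after subsampling; such
+    #  utterances are zeroed out instead of poisoning the summed loss)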
+    ctc_loss = torch.sum(
+        torch.where(
+            ctc_loss != float("inf"),
+            ctc_loss,
+            torch.tensor(0, dtype=torch.float32).to(device),
+        )
+    )
+
+    if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=unsorted_token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            graph_compiler=graph_compiler,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+    """
+
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
+            params.batch_idx_train += 1
+            batch_size = len(batch["supervisions"]["text"])
+
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
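+            # tot_loss is an exponentially decayed running sum: each step the
+            # previous total is scaled by (1 - 1/reset_interval), so it mostly
+            # reflects the last `reset_interval` batches.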
+            # if tot_loss is None:
+            #     logging.warning("Batch mismatch. Skipping ...")
+            #     del batch
+            #     del tot_loss
+            #     continue;
+            # elif tot_loss.isinf() or tot_loss.isnan():
+            #     logging.warning("NaN or Inf loss. Skipping ...")
+            #     del batch
+            #     del tot_loss
+            #     continue;
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+
+            optimizer.zero_grad()
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+
+            if batch_idx % params.log_interval == 0:
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                )
+
+            if batch_idx % params.log_interval == 0:
+
+                if tb_writer is not None:
+                    loss_info.write_summary(
+                        tb_writer, "train/current_", params.batch_idx_train
+                    )
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
+
+            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+                logging.info("Computing validation loss")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, "train/valid_", params.batch_idx_train
+                    )
+        else:
+            logging.warning(
+                f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..."
+            )
+            continue
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=False,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Noam(
+        model.parameters(),
+        model_size=params.attention_dim,
+        factor=params.lr_factor,
+        warm_step=params.warm_step,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.5 and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 30.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_dl = MGB2.train_dataloaders(train_cuts)
+
+    valid_cuts = MGB2.dev_cuts()
+    valid_dl = MGB2.test_dataloaders(valid_cuts)
+
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
+    )
+
+    for epoch in range(params.start_epoch, params.num_epochs):
+        train_dl.sampler.set_epoch(epoch)
+
+        cur_lr = optimizer._rate
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        if rank == 0:
+            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/transformer.py b/egs/mgb2/ASR/conformer_ctc/transformer.py
new file mode 120000
index 000000000..1c3f43fcf
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/transformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/__init__.py b/egs/mgb2/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/local/compile_hlg.py b/egs/mgb2/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mgb2/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/compute_fbank_mgb2.py b/egs/mgb2/ASR/local/compute_fbank_mgb2.py
new file mode 100755
index 000000000..6cae69e41
--- /dev/null
+++ b/egs/mgb2/ASR/local/compute_fbank_mgb2.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the MGB2 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_mgb2():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "train",
+        "test",
+        "dev",
+    )
+    manifests = read_manifests_if_cached(
+        prefix="mgb2", dataset_parts=dataset_parts, output_dir=src_dir
+    )
+    assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
+            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if "train" in partition:
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
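+                # Speed perturbation with factors 0.9 and 1.1 triples the
+                # amount of training data before feature extraction.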
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            logging.info("About to split cuts into smaller chunks.")
+            cut_set = cut_set.trim_to_supervisions(
+                keep_overlapping=False, min_duration=None
+            )
+            cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_mgb2()
diff --git a/egs/mgb2/ASR/local/compute_fbank_musan.py b/egs/mgb2/ASR/local/compute_fbank_musan.py
new file mode 100755
index 000000000..5d0d69a13
--- /dev/null
+++ b/egs/mgb2/ASR/local/compute_fbank_musan.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+    ChunkedLilcomHdf5Writer,
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    combine,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "music",
+        "speech",
+        "noise",
+    )
+    prefix = "musan"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        prefix=prefix,
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        suffix=suffix,
+    )
+    assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+    )
+
+    musan_cuts_path = output_dir / "cuts_musan.jsonl.gz"
+
+    if musan_cuts_path.is_file():
+        logging.info(f"{musan_cuts_path} already exists - skipping")
+        return
+
+    logging.info("Extracting features for Musan")
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        # create chunks of Musan with duration 5 - 10 seconds
+        musan_cuts = (
+            CutSet.from_manifests(
+                recordings=combine(part["recordings"] for part in manifests.values())
+            )
+            .cut_into_windows(10.0)
+            .filter(lambda c: c.duration > 5)
+            .compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/feats_musan",
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        musan_cuts.to_file(musan_cuts_path)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    compute_fbank_musan()
diff --git a/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
new file mode 100755
index 000000000..a8d5117c9
--- /dev/null
+++ b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+"""
+Convert a transcript file containing words to a corpus file containing tokens
+for LM training with the help of a lexicon.
+
+If the lexicon contains phones, the resulting LM will be a phone LM; If the
+lexicon contains word pieces, the resulting LM will be a word piece LM.
+
+If a word has multiple pronunciations, the one that appears first in the lexicon
+is kept; others are removed.
+
+If the input transcript is:
+
+    hello zoo world hello
+    world zoo
+    foo zoo world hellO
+
+and if the lexicon is
+
+    <UNK> SPN
+    hello h e l l o 2
+    hello h e l l o
+    world w o r l d
+    zoo z o o
+
+Then the output is
+
+    h e l l o 2 z o o w o r l d h e l l o 2
+    w o r l d z o o
+    SPN z o o w o r l d SPN
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List
+
+from generate_unique_lexicon import filter_multiple_pronunications
+
+from icefall.lexicon import read_lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="The input transcript file."
+        "We assume that the transcript file consists of "
+        "lines. Each line consists of space separated words.",
+    )
+    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
+    parser.add_argument("--oov", type=str, default="", help="The OOV word.")
+
+    return parser.parse_args()
+
+
+def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: List[str]) -> None:
+    """
+    Args:
+      lexicon:
+        A dict containing pronunciations. Its keys are words and values
+        are pronunciations (i.e., tokens).
+      line:
+        A line of transcript consisting of space(s) separated words.
+      oov_token:
+        The pronunciation of the oov word if a word in `line` is not present
+        in the lexicon.
+    Returns:
+      Return None.
+    """
+    s = ""
+    words = line.strip().split()
+    for w in words:
+        tokens = lexicon.get(w, oov_token)
+        s += " ".join(tokens)
+        s += " "
+    print(s.strip())
+
+
+def main():
+    args = get_args()
+    assert Path(args.lexicon).is_file()
+    assert Path(args.transcript).is_file()
+    assert len(args.oov) > 0
+
+    # Only the first pronunciation of a word is kept
+    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+
+    lexicon = dict(lexicon)
+
+    assert args.oov in lexicon
+
+    oov_token = lexicon[args.oov]
+
+    with open(args.transcript) as f:
+        for line in f:
+            process_line(lexicon=lexicon, line=line, oov_token=oov_token)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/local/display_manifest_statistics.py b/egs/mgb2/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..d3e224905
--- /dev/null
+++ b/egs/mgb2/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in conformer_ctc/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest
+
+
+def main():
+    # path = "./data/fbank/cuts_train.jsonl.gz"
+    path = "./data/fbank/cuts_dev.jsonl.gz"
+    # path = "./data/fbank/cuts_test.jsonl.gz"
+
+    cuts = load_manifest(path)
+    cuts.describe()
+
+
+if __name__ == "__main__":
+    main()
+
+"""
+# train
+
+Cuts count: 1125309
+Total duration (hours): 3403.9
+Speech duration (hours): 3403.9 (100.0%)
+***
+Duration statistics (seconds):
+mean    10.9
+std     10.1
+min     0.2
+25%     5.2
+50%     7.8
+75%     12.7
+99%     52.0
+99.5%   65.1
+99.9%   99.5
+max     228.9
+
+
+# test
+Cuts count: 5365
+Total duration (hours): 9.6
+Speech duration (hours): 9.6 (100.0%)
+***
+Duration statistics (seconds):
+mean    6.4
+std     1.5
+min     1.6
+25%     5.3
+50%     6.5
+75%     7.6
+99%     9.5
+99.5%   9.7
+99.9%   10.3
+max     12.4
+
+# dev
+Cuts count: 5002
+Total duration (hours): 8.5
+Speech duration (hours): 8.5 (100.0%)
+***
+Duration statistics (seconds):
+mean    6.1
+std     1.7
+min     1.5
+25%     4.8
+50%     6.2
+75%     7.4
+99%     9.5
+99.5%   9.7
+99.9%   10.1
+max     20.3
+
+"""
diff --git a/egs/mgb2/ASR/local/generate_unique_lexicon.py b/egs/mgb2/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/mgb2/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
new file mode 100755
index 000000000..3b673db6f
--- /dev/null
+++ b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# Copyright 2022 QCRI (author: Amir Hussein)
+# Apache 2.0
+# This script prepares the graphemic lexicon.
+
+dir=data/local/dict
+lexicon_url1="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_grapheme_lexicon_20160209.bz2";
+lexicon_url2="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_phoneme_lexicon_20140317.bz2";
+stage=0
+lang_dir=download/lm
+mkdir -p $lang_dir
+
+if [ $stage -le 0 ]; then
+  echo "$0: Downloading text for lexicon... $(date)."
+  wget --no-check-certificate -P $lang_dir $lexicon_url1
+  wget --no-check-certificate -P $lang_dir $lexicon_url2
+  bzcat $lang_dir/ar-ar_grapheme_lexicon_20160209.bz2 | sed '1,3d' | awk '{print $1}'  >  $lang_dir/grapheme_lexicon
+  bzcat $lang_dir/ar-ar_phoneme_lexicon_20140317.bz2 | sed '1,3d' | awk '{print $1}' >>  $lang_dir/phoneme_lexicon
+  cat download/lm/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> $lang_dir/uniq_words
+fi
+
+
+if [ $stage -le 0 ]; then
+  echo "$0: processing lexicon text and creating lexicon... $(date)."
+  # remove short vowels (diacritics) and the rare alef wasla
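+  # (In Buckwalter transliteration: F/N/K are tanween, a/u/i are short vowels,
+  #  ~ is shadda, o is sukun and ` is the dagger alif.)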
+  cat $lang_dir/uniq_words |  sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir/grapheme_lexicon.txt
+fi
+
+echo "$0: Lexicon preparation succeeded"
diff --git a/egs/mgb2/ASR/local/prepare_lang.py b/egs/mgb2/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prepare_lang_bpe.py b/egs/mgb2/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
new file mode 100755
index 000000000..99e1fa34d
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+
+# Copyright      2022  Amir Hussein
+# Apache 2.0
+
+# This script prepares a graphemic lexicon, given a column of words.
+
+import argparse
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="""Creates the list of characters and words in lexicon"""
+    )
+    parser.add_argument("input", type=str, help="""Input list of words file""")
+    parser.add_argument("output", type=str, help="""output graphemic lexicon""")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    lex = {}
+    args = get_args()
+    with open(args.input, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            characters = list(line)
+            characters = " ".join(["V" if char == "*" else char for char in characters])
+            lex[line] = characters
+
+    with open(args.output, "w", encoding="utf-8") as fp:
+        for key in sorted(lex):
+            fp.write(key + "  " + lex[key] + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/local/test_prepare_lang.py b/egs/mgb2/ASR/local/test_prepare_lang.py
new file mode 120000
index 000000000..f0f864998
--- /dev/null
+++ b/egs/mgb2/ASR/local/test_prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/test_prepare_lang.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh
new file mode 100755
index 000000000..899d15d97
--- /dev/null
+++ b/egs/mgb2/ASR/prepare.sh
@@ -0,0 +1,234 @@
+#!/usr/bin/env bash
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+set -eou pipefail
+nj=30
+stage=7
+stop_stage=1000
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. 
+#
+#  - $dl_dir/mgb2
+#      
+#      You can download the data from 
+#
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+#
+# Note: MGB2 is not available for direct 
+# download, however you can fill out the form and  
+# download it from https://arabicspeech.org/mgb2 
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  5000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/MGB2,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/mgb2 $dl_dir/mgb2
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare mgb2 manifest"
+  # We assume that you have downloaded the mgb2 corpus
+  # to $dl_dir/mgb2
+  mkdir -p data/manifests
+
+  lhotse prepare mgb2 $dl_dir/mgb2 data/manifests
+  
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for mgb2"
+  mkdir -p data/fbank
+  ./local/compute_fbank_mgb2.py
+  # shuffle the training cuts
+  gunzip -c data/fbank/cuts_train.jsonl.gz | shuf | gzip -c > data/fbank/cuts_train_shuf.jsonl.gz
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  ./local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  if [[ ! -e download/lm/train/text ]]; then
+    # export the training manifests to a Kaldi-style data dir; its "text"
+    # file is used to build the grapheme lexicon
+    lhotse kaldi export \
+      data/manifests/mgb2_recordings_train.jsonl.gz \
+      data/manifests/mgb2_supervisions_train.jsonl.gz \
+      download/lm/train
+  fi
+
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+  ./local/prep_mgb2_lexicon.sh 
+  python local/prepare_mgb2_lexicon.py  $dl_dir/lm/grapheme_lexicon.txt  $dl_dir/lm/lexicon.txt
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
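+  # lexicon.txt ends up with one "word pronunciation" entry per line, e.g.
+  #   !SIL SIL
+  #   <UNK> SPN
+  #   كتاب ك ت ا ب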
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/lm/train" -name "text"
+      )
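+      # each Kaldi-style "text" file has lines of the form "utt-id word1 word2 ...";
+      # keep only the words and drop empty lines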
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2- | sed -r '/^\s*$/d'
+      done > $lang_dir/transcript_words.txt
+    fi
+
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size \
+      --transcript $lang_dir/transcript_words.txt
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
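+      # transcript_tokens.txt holds the transcripts with every word replaced by
+      # its token sequence from lexicon.txt (following the usual icefall layout)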
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p data/lm
+    if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+      # It is used in building HLG
+      ./shared/make_kn_lm.py \
+          -ngram-order 3 \
+          -text $lang_dir/transcript_words.txt \
+          -lm $lang_dir/G.arpa
+
+      python3 -m kaldilm \
+        --read-symbol-table="data/lang_phone/words.txt" \
+        --disambig-symbol='#0' \
+        --max-order=3 \
+        $lang_dir/G.arpa > data/lm/G_3_gram.fst.txt
+    fi
+
+    if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+      # It is used for LM rescoring
+      ./shared/make_kn_lm.py \
+          -ngram-order 4 \
+          -text $lang_dir/transcript_words.txt \
+          -lm $lang_dir/4-gram.arpa
+
+      python3 -m kaldilm \
+        --read-symbol-table="data/lang_phone/words.txt" \
+        --disambig-symbol='#0' \
+        --max-order=4 \
+        $lang_dir/4-gram.arpa > data/lm/G_4_gram.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py b/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 120000
index 000000000..a73848de9
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1 @@
+../conformer_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..02d01b343
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..1463f8f67
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,625 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins        (authors: Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method beam_search \
+    --beam-size 10
+
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 10
+
+(4) fast beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method fast_beam_search \
+    --beam 10 \
+    --max-contexts 4 \
+    --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_2000/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
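+    # encoder_out: (N, T', encoder_dim) with T' the subsampled frame count;
+    # encoder_out_lens holds each utterance's valid length within encoder_out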
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
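+            # keep only the i-th utterance and strip its padding frames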
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for hyp_words, ref_text in zip(hyps, texts):
+
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    test_set_wers = dict()
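+    # maps each decoding-setting key (e.g. "greedy_search") to its WER on this test set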
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if params.decoding_method == "fast_beam_search":
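+        # k2.trivial_graph builds an FSA with self-loops over all non-blank
+        # tokens, so it accepts any token sequence and the search is constrained
+        # only by the beam/contexts/states limits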
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    test_cuts = MGB2.test_cuts()
+    dev_cuts = MGB2.dev_cuts()
+
+    test_dl = MGB2.test_dataloaders(test_cuts)
+    dev_dl = MGB2.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_all_dl = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_all_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..6775ee67e
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..972e44ca4
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/export.py b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..7a5d7f680
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/mgb2/ASR
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    assert args.jit is False, "Support for torchscript will be added later"
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..f5279e151
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..7b417fd89
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/model.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..210374f22
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/optim.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..77ba0873b
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..ff7bfeda9
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
new file mode 120000
index 000000000..b71d7bb81
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/test_model.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..e1b623353
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1176 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins        (authors: Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 200 \
+  --num-buckets 50
+
+# For mixed precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 200	\
+  --num-buckets 50
+
+"""
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import nvidia_smi
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_2000/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need " "to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=True,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
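+# A minimal sketch (not part of the original patch): the running average
+# described in the --average-period help above boils down to the update
+# below.  `demo_update_averaged_model` is a hypothetical name used only for
+# illustration; real training relies on icefall's `update_averaged_model`.
+def demo_update_averaged_model(
+    model_avg: nn.Module,
+    model_cur: nn.Module,
+    average_period: int,
+    batch_idx_train: int,
+) -> None:
+    weight_cur = average_period / batch_idx_train
+    weight_avg = (batch_idx_train - average_period) / batch_idx_train
+    with torch.no_grad():
+        for p_avg, p_cur in zip(model_avg.parameters(), model_cur.parameters()):
+            # model_avg = model * weight_cur + model_avg * weight_avg
+            p_avg.mul_(weight_avg).add_(p_cur, alpha=weight_cur)
+
+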
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for Noam
+            "model_warm_step": 80000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+    reduction="none",
+) -> Tuple[Tensor, MetricsTracker, bool]:
+    """
+    Compute the transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Transducer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        inf_flag = False
+        if not torch.all(is_finite):
+            inf_flag = True
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    # info["utterances"] = feature.size(0)
+    # # averaged input duration in frames over utterances
+    # info["utt_duration"] = feature_lens.sum().item()
+    # # averaged padding proportion over utterances
+    # info["utt_pad_proportion"] = (
+    #     ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    # )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info, inf_flag
+
+
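+# A minimal sketch (not part of the original patch): the piecewise
+# pruned_loss_scale schedule used inside compute_loss above, written out as a
+# standalone helper.  `demo_pruned_loss_scale` is a hypothetical name shown
+# only to make the warmup behaviour explicit.
+def demo_pruned_loss_scale(warmup: float) -> float:
+    if warmup < 1.0:
+        # Still warming up: drop the pruned loss entirely.
+        return 0.0
+    if 1.0 < warmup < 2.0:
+        # Keep the pruned loss small so it cannot overwhelm the simple loss
+        # while the model may not have fully learned the alignment yet.
+        return 0.1
+    # Fully warmed up: use the pruned loss at full scale.
+    return 1.0
+
+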
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            loss, loss_info, inf_flag = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=False,
+            )
+            assert loss.requires_grad is False
+            tot_loss = tot_loss + loss_info
+
+        if world_size > 1:
+            tot_loss.reduce(loss.device)
+
+        loss_value = tot_loss["loss"] / tot_loss["frames"]
+        if loss_value < params.best_valid_loss:
+            params.best_valid_epoch = params.cur_epoch
+            params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+
+        if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
+            if batch_idx < cur_batch_idx:
+                continue
+            cur_batch_idx = batch_idx
+
+            params.batch_idx_train += 1
+            batch_size = len(batch["supervisions"]["text"])
+
+            try:
+                with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                    loss, loss_info, inf_flag = compute_loss(
+                        params=params,
+                        model=model,
+                        sp=sp,
+                        batch=batch,
+                        is_training=True,
+                        warmup=(params.batch_idx_train / params.model_warm_step),
+                    )
+                # summary stats
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+                # NOTE: We use reduction==sum and loss is computed over utterances
+                # in the batch and there is no normalization to it so far.
+                if not inf_flag:
+                    scaler.scale(loss).backward()
+                    scheduler.step_batch(params.batch_idx_train)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+                else:
+                    continue
+            except:  # noqa
+                display_and_save_batch(batch, params=params, sp=sp)
+                raise
+
+            if params.print_diagnostics and batch_idx == 5:
+                return
+
+            if (
+                rank == 0
+                and params.batch_idx_train > 0
+                and params.batch_idx_train % params.average_period == 0
+            ):
+                update_averaged_model(
+                    params=params,
+                    model_cur=model,
+                    model_avg=model_avg,
+                )
+
+            if (
+                params.batch_idx_train > 0
+                and params.batch_idx_train % params.save_every_n == 0
+            ):
+                params.cur_batch_idx = batch_idx
+                save_checkpoint_with_global_batch_idx(
+                    out_dir=params.exp_dir,
+                    global_batch_idx=params.batch_idx_train,
+                    model=model,
+                    model_avg=model_avg,
+                    params=params,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    sampler=train_dl.sampler,
+                    scaler=scaler,
+                    rank=rank,
+                )
+                del params.cur_batch_idx
+                remove_checkpoints(
+                    out_dir=params.exp_dir,
+                    topk=params.keep_last_k,
+                    rank=rank,
+                )
+
+            if batch_idx % params.log_interval == 0:
+                cur_lr = scheduler.get_last_lr()[0]
+                # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea
+                memory_debugging()
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                    f"lr: {cur_lr:.2e}"
+                )
+
+                if tb_writer is not None:
+                    tb_writer.add_scalar(
+                        "train/learning_rate", cur_lr, params.batch_idx_train
+                    )
+
+                    loss_info.write_summary(
+                        tb_writer, "train/current_", params.batch_idx_train
+                    )
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
+
+            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+                logging.info("Computing validation loss")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, "train/valid_", params.batch_idx_train
+                    )
+        else:
+            logging.warning(
+                f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..."
+            )
+            continue
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
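+# A minimal sketch (not part of the original patch): train_one_epoch keeps
+# `tot_loss` as a leaky running sum, scaling the accumulated statistics by
+# (1 - 1 / reset_interval) before adding each new batch, so older batches
+# decay geometrically.  `demo_leaky_running_sum` is a hypothetical helper
+# mirroring that update for a plain list of per-batch losses.
+def demo_leaky_running_sum(losses, reset_interval: int = 200) -> float:
+    total = 0.0
+    for loss in losses:
+        # Same recurrence as tot_loss = tot_loss * (1 - 1/reset_interval) + loss_info
+        total = total * (1 - 1 / reset_interval) + loss
+    return total
+
+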
+def memory_debugging():
+    # memory nvidia debugging
+    nvidia_smi.nvmlInit()
+
+    deviceCount = nvidia_smi.nvmlDeviceGetCount()
+    for i in range(deviceCount):
+        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
+        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+        logging.info(
+            "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format(
+                i,
+                nvidia_smi.nvmlDeviceGetName(handle),
+                100 * info.free / info.total,
+                info.total,
+                info.free,
+                info.used,
+            )
+        )
+
+    nvidia_smi.nvmlShutdown()
+
+
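+# A possible alternative (an assumption, not part of the original patch): if
+# nvidia_smi/pynvml is not installed, a rough allocator-level view of GPU
+# memory can be logged with PyTorch's own counters.  Note that these report
+# what the PyTorch caching allocator holds, not driver-level totals.
+def demo_torch_memory_debugging() -> None:
+    if not torch.cuda.is_available():
+        return
+    for i in range(torch.cuda.device_count()):
+        allocated = torch.cuda.memory_allocated(i)
+        reserved = torch.cuda.memory_reserved(i)
+        logging.info(
+            f"Device {i}: {torch.cuda.get_device_name(i)}, "
+            f"allocated={allocated} bytes, reserved={reserved} bytes"
+        )
+
+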
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    MGB2 = MGB2AsrDataModule(args)
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.5 seconds and 30 seconds
+        #
+        # Caution: There is a reason for selecting these thresholds. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 30.0
+
+    def remove_short_and_long_text(c: Cut):
+        # Keep only texts whose length is between 20 and 450 characters
+
+        return 20 <= len(c.supervisions[0].text) <= 450
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.filter(remove_short_and_long_text)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+    valid_cuts = MGB2.dev_cuts()
+    valid_dl = MGB2.test_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+
+                loss, _, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0,
+                )
+            loss.backward()
+            # clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/shared b/egs/mgb2/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/mgb2/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 207c12bf1..6589579d1 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -263,7 +263,7 @@ class TensorDiagnostic(object):
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
                 rms = (stats**2).mean().sqrt().item()
-                ans += f", mean={mean:.3g}, rms={rms:.3g}"
+                ans += f", mean={mean:.2g}, rms={rms:.2g}"
 
                 # OK, "ans" contains the actual stats, e.g.
                 # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"

From 7700ddcb38b5ba0d91334947e3cac44825f1cf7c Mon Sep 17 00:00:00 2001
From: Weiji Zhuang 
Date: Fri, 2 Dec 2022 17:40:42 +0800
Subject: [PATCH 058/120] update multidataset zipformer results (#728)

---
 egs/librispeech/ASR/RESULTS.md | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index c2ea3d050..0885fb9b6 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -108,21 +108,25 @@ See  for more details.
 [pruned_transducer_stateless8](./pruned_transducer_stateless8)
 
 The tensorboard log can be found at
-
+
 
 You can find a pretrained model, training logs, decoding logs, and decoding
 results at:
-
+
 
 You can use  to deploy it.
 
 Number of model parameters: 70369391, i.e., 70.37 M
 
-|                      | test-clean | test-other  | comment                                |
-|----------------------|------------|-------------|----------------------------------------|
-| greedy search        | 1.87       | 4.38        | --epoch 16 --avg 2 --max-duration 600  |
-| modified beam search | 1.81       | 4.34        | --epoch 16 --avg 2 --max-duration 600  |
-| fast beam search     | 1.91       | 4.33        | --epoch 16 --avg 2 --max-duration 600  |
+| decoding method      | test-clean | test-other | comment            |
+|----------------------|------------|------------|--------------------|
+| greedy_search        | 1.81       | 4.18       | --epoch 20 --avg 4 |
+| fast_beam_search     | 1.82       | 4.15       | --epoch 20 --avg 4 |
+| modified_beam_search | 1.78       | **4.08**   | --epoch 20 --avg 4 |
+| greedy_search        | 1.84       | 4.3        | --epoch 19 --avg 8 |
+| fast_beam_search     |**1.77**    | 4.25       | --epoch 19 --avg 8 |
+| modified_beam_search | 1.81       | 4.16       | --epoch 19 --avg 8 |
+
 
 The training commands are:
 ```bash
@@ -142,15 +146,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 
 The decoding commands are:
 ```bash
-for m in greedy_search fast_beam_search modified_beam_search ; do
-  for epoch in 16; do
-    for avg in 2; do
+for m in greedy_search fast_beam_search modified_beam_search; do
+  for epoch in $(seq 20 -1 10); do
+    for avg in $(seq 9 -1 1); do
       ./pruned_transducer_stateless8/decode.py \
           --epoch $epoch \
           --avg $avg \
           --use-averaged-model 1 \
           --exp-dir ./pruned_transducer_stateless8/exp \
-          --feedforward-dims  "1024,1024,2048,2048,1024" \
+          --feedforward-dims "1024,1024,2048,2048,1024" \
           --max-duration 600 \
           --decoding-method $m
     done

From 8eb4b9d96da0432c1c27901f2964da954583d69a Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Sat, 3 Dec 2022 19:01:10 +0800
Subject: [PATCH 059/120] Combining rnnt loss and k2-ctc loss for Dan's
 Zipformer (#683)

* init files

* add ctc as auxiliary loss and ctc_decode.py

* tuning the scalar of HLG score for 1best, nbest and nbest-oracle

* rename to pruned_transducer_stateless7_ctc

* fix doc

* fix bug, recover the hlg scores

* modify ctc_decode.py, move out the hlg scale

* fix hlg_scale

* add export.py and pretrained.py, and so on

* upload files, update README.md and RESULTS.md

* add CI test
---
 ...ed-transducer-stateless7-ctc-2022-12-01.sh |  147 ++
 ...-librispeech-2022-12-01-stateless7-ctc.yml |  163 +++
 egs/librispeech/ASR/README.md                 |    1 +
 egs/librispeech/ASR/RESULTS.md                |   79 ++
 .../ASR/conformer_ctc3/jit_pretrained.py      |   20 +-
 .../__init__.py                               |    0
 .../asr_datamodule.py                         |    1 +
 .../beam_search.py                            |    1 +
 .../ctc_decode.py                             |  818 +++++++++++
 .../decode.py                                 |  841 +++++++++++
 .../decoder.py                                |    1 +
 .../encoder_interface.py                      |    1 +
 .../export.py                                 |  320 +++++
 .../jit_pretrained.py                         |  271 ++++
 .../jit_pretrained_ctc.py                     |  423 ++++++
 .../joiner.py                                 |    1 +
 .../pruned_transducer_stateless7_ctc/model.py |  198 +++
 .../pruned_transducer_stateless7_ctc/optim.py |    1 +
 .../pretrained.py                             |  353 +++++
 .../pretrained_ctc.py                         |  441 ++++++
 .../scaling.py                                |    1 +
 .../scaling_converter.py                      |    1 +
 .../test_model.py                             |   56 +
 .../pruned_transducer_stateless7_ctc/train.py | 1252 +++++++++++++++++
 .../zipformer.py                              |    1 +
 icefall/utils.py                              |   18 +-
 26 files changed, 5396 insertions(+), 14 deletions(-)
 create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
 create mode 100644 .github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py

diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
new file mode 100755
index 000000000..6642d5f67
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -0,0 +1,147 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/*"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir $repo/exp \
+  --use-averaged-model false \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless7_ctc/jit_pretrained.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --nn-model-filename $repo/exp/cpu_jit.pt \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+for m in ctc-decoding 1best; do
+  ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+    --model-filename $repo/exp/cpu_jit.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for sym in 1 2 3; do
+  log "Greedy search with --max-sym-per-frame $sym"
+
+  ./pruned_transducer_stateless7_ctc/pretrained.py \
+    --method greedy_search \
+    --max-sym-per-frame $sym \
+    --checkpoint $repo/exp/pretrained.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+  log "$method"
+
+  ./pruned_transducer_stateless7_ctc/pretrained.py \
+    --method $method \
+    --beam-size 4 \
+    --checkpoint $repo/exp/pretrained.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+  ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+    --checkpoint $repo/exp/pretrained.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode"  ]]; then
+  mkdir -p pruned_transducer_stateless7_ctc/exp
+  ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh pruned_transducer_stateless7_ctc/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in greedy_search fast_beam_search modified_beam_search; do
+    log "Decoding with $method"
+
+    ./pruned_transducer_stateless7_ctc/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --use-averaged-model 0 \
+      --max-duration $max_duration \
+      --exp-dir pruned_transducer_stateless7_ctc/exp
+  done
+
+  for m in ctc-decoding 1best; do
+    ./pruned_transducer_stateless7_ctc/ctc_decode.py \
+        --epoch 999 \
+        --avg 1 \
+        --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+        --max-duration $max_duration \
+        --use-averaged-model 0 \
+        --decoding-method $m \
+        --hlg-scale 0.6 \
+        --lm-dir data/lm
+  done
+
+  rm pruned_transducer_stateless7_ctc/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
new file mode 100644
index 000000000..ccd8d50d0
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
@@ -0,0 +1,163 @@
+# Copyright      2022  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-01-stateless7-ctc
+# zipformer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_2022_11_11_zipformer:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+
+      - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./pruned_transducer_stateless7_ctc/exp
+
+          cd pruned_transducer_stateless7_ctc
+          echo "results for pruned_transducer_stateless7_ctc"
+          echo "===greedy search==="
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===fast_beam_search==="
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===ctc decoding==="
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===1best==="
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
+          path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index e737d68bd..caa23a49f 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -23,6 +23,7 @@ The following table lists the differences among them.
 | `pruned_transducer_stateless5`        | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
 | `pruned_transducer_stateless6`        | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
 | `pruned_transducer_stateless7`        | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan|
+| `pruned_transducer_stateless7_ctc`    | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head|
 | `pruned_transducer_stateless8`        | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech|
 | `pruned_stateless_emformer_rnnt2`     | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
 | `conv_emformer_transducer_stateless`  | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 0885fb9b6..9e5669f6d 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,84 @@
 ## Results
 
+### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss)
+
+See  for more details.
+
+[pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+Number of model parameters: 70561891, i.e., 70.56 M
+
+|                          | test-clean | test-other  | comment            |
+|--------------------------|------------|-------------|--------------------|
+| greedy search            | 2.23       | 5.19        | --epoch 30 --avg 8 |
+| modified beam search     | 2.21       | 5.12        | --epoch 30 --avg 8 |
+| fast beam search         | 2.23       | 5.18        | --epoch 30 --avg 8 |
+| ctc decoding             | 2.48       | 5.82        | --epoch 30 --avg 9 |
+| 1best                    | 2.43       | 5.22        | --epoch 30 --avg 9 |
+| nbest                    | 2.43       | 5.22        | --epoch 30 --avg 9 |
+| nbest rescoring          | 2.34       | 5.05        | --epoch 30 --avg 9 |
+| whole lattice rescoring  | 2.34       | 5.04        | --epoch 30 --avg 9 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --full-libri 1 \
+  --use-fp16 1 \
+  --max-duration 750 \
+  --exp-dir pruned_transducer_stateless7_ctc/exp \
+  --feedforward-dims  "1024,1024,2048,2048,1024" \
+  --ctc-loss-scale 0.2 \
+  --master-port 12535
+```
+
+The decoding commands for the transducer branch are:
+```bash
+for m in greedy_search fast_beam_search modified_beam_search ; do
+  for epoch in 30; do
+    for avg in 8; do
+      ./pruned_transducer_stateless7_ctc/decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --use-averaged-model 1 \
+          --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+          --feedforward-dims  "1024,1024,2048,2048,1024" \
+          --max-duration 600 \
+          --decoding-method $m
+    done
+  done
+done
+```
+
+The decoding commands for the ctc branch are:
+```bash
+for m in ctc-decoding nbest nbest-rescoring whole-lattice-rescoring; do
+  for epoch in 30; do
+    for avg in 9; do
+      ./pruned_transducer_stateless7_ctc/ctc_decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+          --max-duration 100 \
+          --decoding-method $m \
+          --hlg-scale 0.6 \
+          --lm-dir data/lm
+    done
+  done
+done
+```
+
+
 ### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty)
 
 #### [conformer_ctc3](./conformer_ctc3)
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index c96defd23..5be898e37 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -23,40 +23,44 @@ Usage (for non-streaming mode):
 
 (1) ctc-decoding
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --bpe-model data/lang_bpe_500/bpe.model \
   --method ctc-decoding \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 
 (2) 1best
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --HLG data/lang_bpe_500/HLG.pt \
   --words-file data/lang_bpe_500/words.txt  \
   --method 1best \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 
 (3) nbest-rescoring
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --HLG data/lang_bpe_500/HLG.pt \
   --words-file data/lang_bpe_500/words.txt  \
   --G data/lm/G_4_gram.pt \
   --method nbest-rescoring \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 
 (4) whole-lattice-rescoring
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --HLG data/lang_bpe_500/HLG.pt \
   --words-file data/lang_bpe_500/words.txt  \
   --G data/lm/G_4_gram.pt \
   --method whole-lattice-rescoring \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 """
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
new file mode 100755
index 000000000..9c23e7d66
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
@@ -0,0 +1,818 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Liyong Guo,
+#                                                 Quandong Wang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) ctc-decoding
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding
+
+(2) 1best
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --decoding-method 1best
+
+(3) nbest
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --decoding-method nbest
+
+(4) nbest-rescoring
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --lm-dir data/lm \
+    --decoding-method nbest-rescoring
+
+(5) whole-lattice-rescoring
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --lm-dir data/lm \
+    --decoding-method whole-lattice-rescoring
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="ctc-decoding",
+        help="""Decoding method.
+        Supported values are:
+        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (2) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (3) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+          rescore them with an n-gram LM (e.g., a 4-gram LM); the path with
+          the highest score is the decoding result.
+        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+          n-gram LM (e.g., a 4-gram LM); the best path of the rescored
+          lattice is the decoding result.
+        - (6) nbest-oracle. Its WER is the lower bound of what any n-best
+          rescoring method can achieve. Useful for debugging n-best
+          rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--hlg-scale",
+        type=float,
+        default=0.8,
+        help="""The scale to be applied to `hlg.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The n-gram LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals the
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
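+             For example, the returned dict may look like
+             {"lm_scale_0.7": [["HELLO", "WORLD"], ...]}.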
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
+        - If params.decoding_method is "whole-lattice-rescoring", it uses whole-lattice
+          LM rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
+    nnet_output = model.ctc_output(encoder_out)
+    # nnet_output is (N, T, C)
+
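+    # Each row of supervision_segments is (sequence_idx, start_frame,
+    # num_frames); frame offsets and counts are measured after subsampling.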
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.decoding_method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+    if params.decoding_method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.decoding_method in ["1best", "nbest"]:
+        if params.decoding_method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.decoding_method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]
+
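+    # Rescore with a range of LM scales; save_results() later sorts the
+    # per-scale WERs and reports the best setting for each test set.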
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.decoding_method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.decoding_method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.decoding_method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
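+    # Sort the settings by WER (ascending) so that the best one comes first.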
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "ctc-decoding",
+        "1best",
+        "nbest",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "nbest-oracle",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    if params.decoding_method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
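+        # Scale the graph scores; --hlg-scale controls the weight of the
+        # HLG (graph/LM) scores relative to the acoustic scores.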
+        HLG.scores *= params.hlg_scale
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.decoding_method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.decoding_method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
new file mode 100755
index 000000000..32a9b6bb2
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
@@ -0,0 +1,841 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search".
+               If beam search with a beam size of 7 is used, it would be
+               "beam_size_7".
+        - value: It contains the decoding result. `len(value)` equals the
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
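+    # In simulated streaming mode, append `left_context` padding frames
+    # (filled with LOG_EPS) and extend the lengths accordingly before
+    # running the encoder chunk by chunk.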
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
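+        # greedy_search and beam_search decode one utterance at a time.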
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_size_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
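+    # Sort the settings by WER (ascending) so that the best one comes first.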
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py
new file mode 120000
index 000000000..33944d0d2
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
new file mode 100755
index 000000000..59a393739
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
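+
+For example, a minimal loading sketch (adjust the path as needed):
+
+    import torch
+
+    model = torch.jit.load("cpu_jit.pt")
+    model.eval()
+    # model.to("cuda")  # optionally move the parameters to a CUDA device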
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
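+
+For example, a minimal loading sketch (assuming `model` was created by
+get_transducer_model()):
+
+    from icefall.checkpoint import load_checkpoint
+
+    load_checkpoint("pretrained.pt", model)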
+
+To use the generated file with `pruned_transducer_stateless7_ctc/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless7_ctc/decode.py \
+        --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptabe.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
new file mode 100755
index 000000000..280b95984
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+./pruned_transducer_stateless7_ctc/jit_pretrained.py \
+  --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model cpu_jit.pt",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def greedy_search(
+    model: torch.jit.ScriptModule,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        A 3-D tensor of shape (N, T, C)
+      encoder_out_lens:
+        A 1-D tensor of shape (N,).
+    Returns:
+      Return the decoded results for each utterance.
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
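+    # Pack the encoder output so that, at each time step, only the utterances
+    # that are still within their valid length are processed. Every hypothesis
+    # starts with `context_size` blanks; at each frame we take the argmax of
+    # the joiner output and, whenever any utterance emits a non-blank symbol,
+    # the decoder (prediction network) output is recomputed from the updated
+    # contexts.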
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    device = encoder_out.device
+    blank_id = 0  # hard-code to 0
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    context_size = model.decoder.context_size
+    hyps = [[blank_id] * context_size for _ in range(N)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (N, context_size)
+
+    decoder_out = model.decoder(
+        decoder_input,
+        need_pad=torch.tensor([False]),
+    ).squeeze(1)
+
+    offset = 0
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = packed_encoder_out.data[start:end]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
+        offset = end
+
+        decoder_out = decoder_out[:batch_size]
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+        )
+        # logits' shape: (batch_size, vocab_size)
+
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=torch.tensor([False]),
+            )
+            decoder_out = decoder_out.squeeze(1)
+
+    sorted_ans = [h[context_size:] for h in hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(args.nn_model_filename)
+
+    model.eval()
+
+    model.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {args.sound_files}")
+    waves = read_sound_files(
+        filenames=args.sound_files,
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
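+    # Pad the features to a common length. math.log(1e-10) is a very small
+    # log-mel value, so the padded frames resemble silence.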
+    features = pad_sequence(
+        features,
+        batch_first=True,
+        padding_value=math.log(1e-10),
+    )
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features,
+        x_lens=feature_lengths,
+    )
+
+    hyps = greedy_search(
+        model=model,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+    )
+    s = "\n"
+    for filename, hyp in zip(args.sound_files, hyps):
+        words = sp.decode(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
new file mode 100755
index 000000000..d3343d34a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
@@ -0,0 +1,423 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+(1) ctc-decoding
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) 1best
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+
+(3) nbest-rescoring
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+
+(4) whole-lattice-rescoring
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from ctc_decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model.",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(args.model_filename)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features,
+        x_lens=feature_lengths,
+    )
+    nnet_output = model.ctc_output(encoder_out)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
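+        # Build the CTC topology H over the BPE tokens, intersect it with the
+        # network's log-probs to obtain a lattice, take the best path, and map
+        # the resulting token IDs back to words with the BPE model.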
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py
new file mode 120000
index 000000000..ecfb6dd8a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
new file mode 100644
index 000000000..a6e919e2f
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
@@ -0,0 +1,198 @@
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang, Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+from icefall.utils import add_sos
+
+
+class Transducer(nn.Module):
+    """It implements https://arxiv.org/pdf/1211.3711.pdf
+    "Sequence Transduction with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        decoder: nn.Module,
+        joiner: nn.Module,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+          decoder:
+            It is the prediction network in the paper. Its input shape
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
+          joiner:
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
+            unnormalized probs, i.e., not processed by log-softmax.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+        assert hasattr(decoder, "blank_id")
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joiner = joiner
+
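+        # The two "simple" projections below map the encoder and decoder
+        # outputs to vocab-size logits; they feed the smoothed (simple)
+        # RNN-T loss, whose gradients provide the pruning bounds.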
+        self.simple_am_proj = nn.Linear(
+            encoder_dim,
+            vocab_size,
+        )
+        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+
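+        # Auxiliary CTC head: maps the encoder output to per-token
+        # log-probabilities. Its output is returned by forward() and is used
+        # for CTC decoding.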
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(encoder_dim, vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          y:
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            utterance.
+          prune_range:
+            The prune range for the rnnt loss; it specifies how many symbols
+            (context) are considered for each frame when computing the loss.
+          am_scale:
+            The scale used to smooth the loss with the am part (the output
+            of the encoder network).
+          lm_scale:
+            The scale used to smooth the loss with the lm part (the output
+            of the prediction network).
+        Returns:
+          Return a tuple containing simple loss, pruned loss, and ctc-output.
+
+        Note:
+           Regarding am_scale & lm_scale: they make the loss function take
+           the form
+              lm_scale * lm_probs + am_scale * am_probs +
+              (1-lm_scale-am_scale) * combined_probs
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.num_axes == 2, y.num_axes
+
+        assert x.size(0) == x_lens.size(0) == y.dim0
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)
+
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
+        blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
+
+        # sos_y_padded: [B, S + 1], start with SOS.
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+        # decoder_out: [B, S + 1, decoder_dim]
+        decoder_out = self.decoder(sos_y_padded)
+
+        # Note: y does not start with SOS
+        # y_padded : [B, S]
+        y_padded = y.pad(mode="constant", padding_value=0)
+
+        y_padded = y_padded.to(torch.int64)
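+        # Each row of `boundary` is [begin_symbol, begin_frame, end_symbol,
+        # end_frame]; the k2 RNN-T losses use it to know the valid symbol and
+        # frame lengths of each utterance.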
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
+
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, encoder_dim]
+        # lm_pruned : [B, T, prune_range, decoder_dim]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=self.joiner.encoder_proj(encoder_out),
+            lm=self.joiner.decoder_proj(decoder_out),
+            ranges=ranges,
+        )
+
+        # logits : [B, T, prune_range, vocab_size]
+
+        # project_input=False since we already applied the joiner's input
+        # projections (encoder_proj and decoder_proj) prior to
+        # do_rnnt_pruning (this is an optimization for speed).
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
+
+        return (simple_loss, pruned_loss, ctc_output)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
new file mode 100755
index 000000000..2f1b1a49f
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless7_ctc/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless7_ctc/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7_ctc/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
new file mode 100755
index 000000000..74aef1bc7
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -0,0 +1,441 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) ctc-decoding
+./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) 1best
+./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(3) nbest-rescoring
+./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+
+(4) whole-lattice-rescoring
+./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from ctc_decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring - Extract n paths from the decoding lattice,
+            rescore them with an LM; the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, at the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+    params.vocab_size = params.num_classes
+    params.blank_id = 0
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
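+    # Padding with log(1e-10), i.e. a very small log-energy, makes the padded
+    # frames look like silence in the log-mel feature domain.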
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features,
+        x_lens=feature_lengths,
+    )
+    nnet_output = model.ctc_output(encoder_out)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
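+    # Each row of supervision_segments is (sequence_index, start_frame,
+    # num_frames); here every utterance starts at frame 0 and spans the full
+    # (padded) number of output frames.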
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
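+        # H is a CTC topology over the BPE tokens; intersecting it with the
+        # frame-level posteriors (nnet_output) below yields the decoding lattice.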
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py
new file mode 100755
index 000000000..e482d2040
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless7_ctc/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = "2,4,3,2,4"
+    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
+    params.feedforward_dims = "1024,1024,2048,2048,1024"
+    params.nhead = "8,8,8,8,8"
+    params.encoder_dims = "384,384,384,384,384"
+    params.attention_dims = "192,192,192,192,192"
+    params.encoder_unmasked_dims = "256,256,256,256,256"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    params.decoder_dim = 512
+    params.joiner_dim = 512
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+def main():
+    test_model_1()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
new file mode 100755
index 000000000..abfd56e5a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -0,0 +1,1252 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7_ctc/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7_ctc/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for ctc loss
+            "beam_size": 10,
+            "use_double_scores": True,
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
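+        # e.g. to_int_tuple("2,4,3,2,4") == (2, 4, 3, 2, 4)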
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      sp:
+        The BPE model.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss, ctc_output = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
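+        # With the defaults in this file (warm_step=2000, simple_loss_scale=0.5),
+        # at batch 1000 these scales are 0.75 and 0.55; once batch_idx_train
+        # reaches warm_step they stay at 0.5 and 1.0 respectively.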
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    # Compute ctc loss
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        supervision_segments, token_ids = encode_supervisions(
+            supervisions,
+            subsampling_factor=params.subsampling_factor,
+            token_ids=token_ids,
+        )
+
+    # Works with a BPE model
+    decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+    dense_fsa_vec = k2.DenseFsaVec(
+        ctc_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
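+    # allow_truncate permits the supervision to be up to subsampling_factor - 1
+    # frames longer than the network output, so rounding from the subsampling
+    # does not cause a length-mismatch error.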
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction="sum",
+        use_double_scores=params.use_double_scores,
+    )
+    assert ctc_loss.requires_grad == is_training
+    loss += params.ctc_loss_scale * ctc_loss
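+    # Final loss = simple_loss_scale * simple_loss
+    #            + pruned_loss_scale * pruned_loss
+    #            + ctc_loss_scale * ctc_loss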
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
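+    # When resuming from a checkpoint saved in the middle of an epoch,
+    # cur_batch_idx is the batch index to resume from; earlier batches are
+    # skipped in the loop below.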
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
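+            # tot_loss is an exponentially decayed running sum: each step the
+            # previous total is scaled by (1 - 1/reset_interval) before adding
+            # the current batch's stats.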
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
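+    # init_scale=1.0 starts the fp16 grad scaler low; train_one_epoch manually
+    # grows it again whenever it drops below 1.0 (and, less often, below 8.0).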
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py
new file mode 120000
index 000000000..79b076556
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/icefall/utils.py b/icefall/utils.py
index d852491c8..99e51a2a9 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -175,11 +175,13 @@ class AttributeDict(dict):
 
 
 def encode_supervisions(
-    supervisions: dict, subsampling_factor: int
-) -> Tuple[torch.Tensor, List[str]]:
+    supervisions: dict,
+    subsampling_factor: int,
+    token_ids: Optional[List[List[int]]] = None,
+) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]:
     """
     Encodes Lhotse's ``batch["supervisions"]`` dict into
-    a pair of torch Tensor, and a list of transcription strings.
+    a pair consisting of a torch Tensor and a list of transcription strings or token-index lists.
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
@@ -208,10 +210,14 @@ def encode_supervisions(
 
     indices = torch.argsort(supervision_segments[:, 2], descending=True)
     supervision_segments = supervision_segments[indices]
-    texts = supervisions["text"]
-    texts = [texts[idx] for idx in indices]
 
-    return supervision_segments, texts
+    if token_ids is None:
+        texts = supervisions["text"]
+        res = [texts[idx] for idx in indices]
+    else:
+        res = [token_ids[idx] for idx in indices]
+
+    return supervision_segments, res
 
 
 def get_texts(

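A minimal sketch of the updated `encode_supervisions` interface, assuming icefall is importable; the tiny fake batch below is purely illustrative. Without `token_ids` the transcripts are returned, reordered by descending duration; with `token_ids` the corresponding token-id lists are returned in the same order:

```python
# Illustrative call of encode_supervisions with and without token_ids.
import torch
from icefall.utils import encode_supervisions

supervisions = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([0, 0]),
    "num_frames": torch.tensor([80, 120]),
    "text": ["HELLO WORLD", "GOOD MORNING"],
}

# Old-style call: returns the transcripts, reordered by descending duration.
segments, texts = encode_supervisions(supervisions, subsampling_factor=4)

# New-style call: pass pre-computed token ids and get them back in the same order.
token_ids = [[5, 9, 3], [7, 2, 8, 4]]
segments, ids = encode_supervisions(
    supervisions, subsampling_factor=4, token_ids=token_ids
)
print(texts)  # ['GOOD MORNING', 'HELLO WORLD'] -- longer utterance first
print(ids)    # [[7, 2, 8, 4], [5, 9, 3]]
```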
From e6a67270128f607f49c81327190aca63bb3bb4eb Mon Sep 17 00:00:00 2001
From: Senyan Li <1149593720@qq.com>
Date: Sat, 3 Dec 2022 23:50:49 +0800
Subject: [PATCH 060/120] Add Tibetan Amdo dialect xbmu_amdo31 in egs (#706)

* add egs/xbmu_amdo31

* fix xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py

* fix xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py

* fix xbmu_amdo31/ASR/prepare.sh

* add RESULTS.md and README.md

* fix pruned_transducer_stateless5 decode.py

* add transducer stateless7

* fix transducer_stateless7

* fix RESULTS.md error

* Add pruned_transducer_stateless7 validation set results
---
 egs/xbmu_amdo31/ASR/README.md                 |   16 +
 egs/xbmu_amdo31/ASR/RESULTS.md                |   92 ++
 egs/xbmu_amdo31/ASR/local/compile_hlg.py      |    1 +
 egs/xbmu_amdo31/ASR/local/compile_lg.py       |    1 +
 .../ASR/local/compute_fbank_musan.py          |    1 +
 .../ASR/local/compute_fbank_xbmu_amdo31.py    |  130 ++
 .../convert_transcript_words_to_tokens.py     |    1 +
 egs/xbmu_amdo31/ASR/local/filter_cuts.py      |    1 +
 .../ASR/local/generate_unique_lexicon.py      |    1 +
 egs/xbmu_amdo31/ASR/local/prepare_lang.py     |    1 +
 egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py |    1 +
 .../ASR/local/prepare_lm_training_data.py     |    1 +
 .../ASR/local/sort_lm_training_data.py        |    1 +
 egs/xbmu_amdo31/ASR/local/train_bpe_model.py  |    1 +
 .../ASR/local/validate_bpe_lexicon.py         |    1 +
 egs/xbmu_amdo31/ASR/prepare.sh                |  357 +++++
 .../pruned_transducer_stateless5/__init__.py  |    0
 .../asr_datamodule.py                         |  408 ++++++
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless5/conformer.py |    1 +
 .../pruned_transducer_stateless5/decode.py    |  970 +++++++++++++
 .../decode_stream.py                          |    1 +
 .../pruned_transducer_stateless5/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless5/export.py    |  287 ++++
 .../pruned_transducer_stateless5/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless5/lstmp.py |    1 +
 .../ASR/pruned_transducer_stateless5/model.py |    1 +
 .../ASR/pruned_transducer_stateless5/optim.py |    1 +
 .../pretrained.py                             |  344 +++++
 .../pruned_transducer_stateless5/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../streaming_beam_search.py                  |    1 +
 .../streaming_decode.py                       |    1 +
 .../test_model.py                             |   65 +
 .../ASR/pruned_transducer_stateless5/train.py | 1187 ++++++++++++++++
 .../pruned_transducer_stateless7/__init__.py  |    0
 .../asr_datamodule.py                         |    1 +
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless7/decode.py    |  843 ++++++++++++
 .../pruned_transducer_stateless7/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless7/export.py    |    1 +
 .../jit_pretrained.py                         |    1 +
 .../pruned_transducer_stateless7/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless7/model.py |    1 +
 .../ASR/pruned_transducer_stateless7/optim.py |    1 +
 .../pretrained.py                             |  355 +++++
 .../pruned_transducer_stateless7/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../test_model.py                             |    1 +
 .../ASR/pruned_transducer_stateless7/train.py | 1224 +++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |    1 +
 egs/xbmu_amdo31/ASR/shared                    |    1 +
 54 files changed, 6317 insertions(+)
 create mode 100644 egs/xbmu_amdo31/ASR/README.md
 create mode 100644 egs/xbmu_amdo31/ASR/RESULTS.md
 create mode 120000 egs/xbmu_amdo31/ASR/local/compile_hlg.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/compile_lg.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
 create mode 100755 egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/filter_cuts.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lang.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/train_bpe_model.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
 create mode 100755 egs/xbmu_amdo31/ASR/prepare.sh
 create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py
 create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
 create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/xbmu_amdo31/ASR/shared

diff --git a/egs/xbmu_amdo31/ASR/README.md b/egs/xbmu_amdo31/ASR/README.md
new file mode 100644
index 000000000..0a441d070
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/README.md
@@ -0,0 +1,16 @@
+# Introduction
+
+XBMU-AMDO31 is an open-source Amdo Tibetan speech corpus published by Northwest Minzu University.
+It is publicly available at https://huggingface.co/datasets/syzym/xbmu_amdo31.
+
+The XBMU-AMDO31 dataset is a speech recognition corpus of the Amdo Tibetan dialect.
+The open-source corpus contains 31 hours of speech data and the resources needed
+to build speech recognition systems, including transcribed texts and a Tibetan
+pronunciation lexicon.
+(The lexicon was created for the Lhasa dialect and is reused for the Amdo dialect
+because of the uniformity of written Tibetan.)
+The dataset can be used to train a model for Amdo Tibetan automatic speech recognition (ASR).
+
+This recipe includes several ASR models trained with XBMU-AMDO31.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/RESULTS.md b/egs/xbmu_amdo31/ASR/RESULTS.md
new file mode 100644
index 000000000..1bd9b2e2b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### XBMU-AMDO31 BPE training result (Stateless Transducer)
+
+#### Pruned transducer stateless 5
+
+[./pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+It uses pruned RNN-T.
+
+A pre-trained model and decoding logs can be found at 
+
+You can use  to deploy it.
+
+Number of model parameters: 87801200, i.e., 87.8 M
+
+|                        | test | dev  | comment                               |
+|------------------------|------|------|---------------------------------------|
+| greedy search          | 11.06| 11.73| --epoch 28 --avg 23 --max-duration 600|
+| beam search            | 10.64| 11.42| --epoch 28 --avg 23 --max-duration 600|
+| modified beam search   | 10.57| 11.24| --epoch 28 --avg 23 --max-duration 600|
+
+
+Training command is:
+
+```bash
+cd egs/xbmu_amdo31/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless5/train.py
+```
+
+**Caution**: It uses `--context-size=1`.
+
+
+The decoding command is:
+```bash
+for method in greedy_search beam_search modified_beam_search;
+do
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 23 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method $method
+done
+```
+
+#### pruned_transducer_stateless7 (zipformer)
+
+See  for more details.
+
+[pruned_transducer_stateless7](./pruned_transducer_stateless7)
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+You can use  to deploy it.
+
+Number of model parameters: 70369391, i.e., 70.37 M
+
+|                      | test | dev  | comment                                |
+|----------------------|------|------|----------------------------------------|
+| greedy search        | 10.06| 10.59| --epoch 23 --avg 11 --max-duration 600 |
+| beam search          | 9.77 | 10.11| --epoch 23 --avg 11 --max-duration 600 |
+| modified beam search | 9.7  | 10.12| --epoch 23 --avg 11 --max-duration 600 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless7/train.py
+```
+
+The decoding commands are:
+```bash
+for m in greedy_search beam_search modified_beam_search; do
+  for epoch in 23; do
+    for avg in 11; do
+      ./pruned_transducer_stateless7/decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless7/exp \
+          --max-duration 600 \
+          --decoding-method $m
+    done
+  done
+done
+```
diff --git a/egs/xbmu_amdo31/ASR/local/compile_hlg.py b/egs/xbmu_amdo31/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compile_lg.py b/egs/xbmu_amdo31/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
new file mode 100755
index 000000000..a593e7be3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the XBMU-AMDO31 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to the bpe.model. If not None, we will remove short and
+        long utterances before extracting features""",
+    )
+    return parser.parse_args()
+
+
+def compute_fbank_xbmu_amdo31(bpe_model: Optional[str] = None):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    if bpe_model:
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    dataset_parts = (
+        "train",
+        "dev",
+        "test",
+    )
+    prefix = "xbmu_amdo31"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if bpe_model:
+                cut_set = filter_cuts(cut_set, sp)
+
+            if "train" in partition:
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    compute_fbank_xbmu_amdo31(bpe_model=args.bpe_model)
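Once this script has run, the generated cut manifests can be inspected directly with lhotse; a minimal sketch, assuming fbank features have already been computed for the train split:

```python
# Illustrative inspection of the manifests written by compute_fbank_xbmu_amdo31.py.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/xbmu_amdo31_cuts_train.jsonl.gz")
cut = next(iter(cuts))
print(cut.id, cut.duration, cut.supervisions[0].text)
print("fbank shape:", cut.load_features().shape)  # (num_frames, 80)
```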
diff --git a/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/filter_cuts.py b/egs/xbmu_amdo31/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang.py b/egs/xbmu_amdo31/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
new file mode 120000
index 000000000..abc00d421
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lm_training_data.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
new file mode 120000
index 000000000..1d6ccbe33
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/sort_lm_training_data.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/train_bpe_model.py b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh
new file mode 100755
index 000000000..32ae440f7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/prepare.sh
@@ -0,0 +1,357 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/xbmu_amdo31
+#      You can find data, resource, etc, inside it.
+#      You can download them from https://huggingface.co/datasets/syzym/xbmu_amdo31
+#
+#  - $dl_dir/lm
+#      This directory contains the following files downloaded from
+#       git lfs install
+#       https://huggingface.co/syzym/xbmu_amdo31_lm
+#
+#        - tibetan.3-gram.arpa
+#        - tibetan.4-gram.arpa
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  1000
+  500
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "stage -1: Download LM"
+  # We assume that you have installed the git-lfs, if not, you could install it
+  # using: `sudo apt-get install git-lfs && git-lfs install`
+  git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+
+  if [ ! -f $dl_dir/lm/tibetan.3-gram.arpa ]; then
+    git clone https://huggingface.co/syzym/xbmu_amdo31_lm $dl_dir/lm
+    pushd $dl_dir/lm
+    git lfs pull --include "tibetan.3-gram.arpa"
+    git lfs pull --include "tibetan.4-gram.arpa"
+    popd
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/xbmu_amdo31,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/xbmu_amdo31 $dl_dir/xbmu_amdo31
+  #
+  
+  if [ ! -d $dl_dir/xbmu_amdo31 ]; then
+    git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+    lhotse download xbmu-amdo31 $dl_dir
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare xbmu_amdo31 manifest"
+  # We assume that you have downloaded the xbmu_amdo31 corpus
+  # to $dl_dir/xbmu_amdo31
+  if [ ! -f data/manifests/.xbmu_amdo31_manifests.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare xbmu-amdo31 $dl_dir/xbmu_amdo31 data/manifests
+    touch data/manifests/.xbmu_amdo31_manifests.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.musan_manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_manifests.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for xbmu_amdo31"
+  if [ ! -f data/fbank/.xbmu_amdo31.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_xbmu_amdo31.py
+    touch data/fbank/.xbmu_amdo31.done
+  fi
+fi
+
+
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.musan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/xbmu_amdo31/resource/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  ./local/generate_unique_lexicon.py --lang-dir $lang_dir
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data to train phone based bigram P"
+      xbmu_amdo31_text=$dl_dir/xbmu_amdo31/data/transcript/transcript_clean.txt
+      xbmu_amdo31_train_uid=$dl_dir/xbmu_amdo31/data/transcript/xbmu_amdo31_train_uid
+      find $dl_dir/xbmu_amdo31/data/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '-' '{print $NF}' > $xbmu_amdo31_train_uid
+      awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $xbmu_amdo31_train_uid $xbmu_amdo31_text |
+        cut -d " " -f 2- > $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $dl_dir/lm/tibetan.3-gram.arpa > data/lm/G_3_gram.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      $dl_dir/lm/tibetan.4-gram.arpa > data/lm/G_4_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Compile LG"
+  ./local/compile_lg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir
+  done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Generate LM training data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $dl_dir/lm/lm_train.txt \
+      --lm-archive $out_dir/lm_data.pt
+  done
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Generate LM validation data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/valid.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/dev_text
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $out_dir/valid.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/valid.txt \
+      --lm-archive $out_dir/lm_data-valid.pt
+  done
+fi
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Generate LM test data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/test.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/test_text
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $out_dir/test.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/test.txt \
+      --lm-archive $out_dir/lm_data-test.pt
+  done
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals to the number of BPE tokens
+  # in a sentence.
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data.pt \
+      --out-lm-data $out_dir/sorted_lm_data.pt \
+      --out-statistics $out_dir/statistics.txt
+
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data-valid.pt \
+      --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+      --out-statistics $out_dir/statistics-valid.txt
+
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data-test.pt \
+      --out-lm-data $out_dir/sorted_lm_data-test.pt \
+      --out-statistics $out_dir/statistics-test.txt
+  done
+fi
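For context, the `G_*.fst.txt` files produced in Stage 8 are OpenFst text that k2 can load directly; a minimal sketch of what downstream scripts such as `compile_hlg.py` do with them (the cached `.pt` path is hypothetical, not part of this recipe):

```python
# Illustrative loading of the 3-gram G generated by prepare.sh (Stage 8).
import k2
import torch

with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

G = k2.arc_sort(G)                              # arc-sort before composition
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")  # hypothetical cache location
```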
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 100644
index 000000000..55d5f4636
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1,408 @@
+# Copyright      2021  Piotr Żelasko
+# Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
+# Copyright      2022  Northwest Minzu University     (Author: Senyan Li)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import CutConcatenate  # noqa F401 for PrecomputedFeatures
+from lhotse.dataset import (
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import AudioSamples  # noqa F401 For AudioSamples
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class Xbmu_AmdoAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. XBMU-AMDO31 dev
+    and test).
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
+        )
+
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse's versions, the default of num_frame_masks is
+            # different.
+            num_frame_masks = 10
+            num_frame_masks_parameter = inspect.signature(
+                SpecAugment.__init__
+            ).parameters["num_frame_masks"]
+            if num_frame_masks_parameter.default == 1:
+                num_frame_masks = 2
+            logging.info(f"Num frame mask: {num_frame_masks}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=num_frame_masks,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        train = K2SpeechRecognitionDataset(
+            input_strategy=eval(self.args.input_strategy)(),
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_train.jsonl.gz"
+        logging.info(f"About to get train cuts from {f}")
+        cuts_train = load_manifest_lazy(f)
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_dev.jsonl.gz"
+        logging.info(f"About to get valid cuts from {f}")
+        cuts_valid = load_manifest_lazy(f)
+        return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_test.jsonl.gz"
+        logging.info(f"About to get test cuts from {f}")
+        cuts_test = load_manifest_lazy(f)
+        return cuts_test
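A minimal usage sketch of `Xbmu_AmdoAsrDataModule`, assuming prepare.sh has produced the fbank and MUSAN manifests under data/fbank; the training and decoding scripts build the argparse namespace through their own parsers, so the empty argv here simply picks up the defaults:

```python
# Illustrative wiring of the data module outside of train.py/decode.py.
import argparse

from asr_datamodule import Xbmu_AmdoAsrDataModule

parser = argparse.ArgumentParser()
Xbmu_AmdoAsrDataModule.add_arguments(parser)
args = parser.parse_args([])  # defaults: data/fbank manifests, max-duration 200, ...

datamodule = Xbmu_AmdoAsrDataModule(args)
train_dl = datamodule.train_dataloaders(datamodule.train_cuts())
valid_dl = datamodule.valid_dataloaders(datamodule.valid_cuts())

batch = next(iter(valid_dl))
print(batch["inputs"].shape)  # (N, T, 80) fbank features
```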
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..6a67e26f8
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,970 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao,
+#                                                 Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+(4) fast beam search (one best)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+(5) fast beam search (nbest)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(7) fast beam search (with LG)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(8) modified beam search with RNNLM shallow fusion
+./pruned_transducer_stateless5/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam-size 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+    modified_beam_search_rnnlm_shallow_fusion,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_LG
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is fast_beam_search_LG,
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is fast_beam_search_LG,
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the number of checkpoints to average.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-num-layers",
+        type=int,
+        default=4,
+        help="Number of RNN layers the model",
+    )
+    parser.add_argument(
+        "--rnn-lm-tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
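+
+    # An illustrative RNN LM shallow fusion invocation (values are placeholders,
+    # not tuned recommendations):
+    #   ./pruned_transducer_stateless5/decode.py \
+    #     --epoch 30 --avg 15 \
+    #     --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    #     --beam-size 4 \
+    #     --rnn-lm-exp-dir rnn_lm/exp \
+    #     --rnn-lm-epoch 7 \
+    #     --rnn-lm-avg 1 \
+    #     --rnn-lm-scale 0.3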
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    rnnlm: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search".
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7".
+        - value: It contains the decoding result. `len(value)` equals the
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an LG graph.
+        Used only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
+        fast_beam_search_nbest_LG.
+      rnnlm:
+        The RNN LM used for shallow fusion. Used only when --decoding-method
+        is modified_beam_search_rnnlm_shallow_fusion.
+      rnnlm_scale:
+        The scale of the RNN LM scores in shallow fusion.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if (
+        params.decoding_method == "fast_beam_search"
+        or params.decoding_method == "fast_beam_search_LG"
+    ):
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        if params.decoding_method == "fast_beam_search":
+            for hyp in sp.decode(hyp_tokens):
+                hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    rnnlm: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an LG graph.
+        Used only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle, and
+        fast_beam_search_nbest_LG.
+      rnnlm:
+        The RNN LM used for shallow fusion. Used only when --decoding-method
+        is modified_beam_search_rnnlm_shallow_fusion.
+      rnnlm_scale:
+        The scale of the RNN LM scores in shallow fusion.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        logging.info(f"Decoding {batch_idx}-th batch")
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
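+    # For each decoding key this writes (file names follow params.suffix and
+    # are illustrative here):
+    #   res_dir/recogs-<test_set>-<key>-<suffix>.txt  (transcripts)
+    #   res_dir/errs-<test_set>-<key>-<suffix>.txt    (per-word error stats)
+    # and finally a single res_dir/wer-summary-<test_set>-<key>-<suffix>.txt
+    # summarizing the WER of every setting.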
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_LG",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+        "modified_beam_search_rnnlm_shallow_fusion",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+        if "LG" in params.decoding_method:
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
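+    # A hypothetical resulting suffix for RNN LM shallow fusion decoding:
+    #   epoch-30-avg-15-modified_beam_search_rnnlm_shallow_fusion-beam-size-4-rnnlm-lm-scale-0.3-use-averaged-model
+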
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    rnn_lm_model = None
+    rnn_lm_scale = params.rnn_lm_scale
+    if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.vocab_size,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
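+        # Note: only the single checkpoint rnn_lm_exp_dir/epoch-{rnn_lm_epoch}.pt
+        # is loaded below, so --rnn-lm-avg has to be 1 here.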
+        assert params.rnn_lm_avg == 1
+
+        load_checkpoint(
+            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+            rnn_lm_model,
+        )
+        rnn_lm_model.to(device)
+        rnn_lm_model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if "LG" in params.decoding_method:
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    test_cuts = xbmu_amdo.test_cuts()
+
+    test_dl = xbmu_amdo.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dl = [test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnn_lm_scale,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
new file mode 120000
index 000000000..d59ef95f7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..54f656859
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model, if the models in exp-dir
+        are streaming model, this should be True.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.streaming_model:
+        assert params.causal_convolution
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        convert_scaled_to_non_scaled(model, inplace=True)
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..74a2210c3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
new file mode 120000
index 000000000..1199a61d6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
new file mode 120000
index 000000000..f29284163
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
new file mode 100755
index 000000000..9aad32014
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless4/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 24
+    params.dim_feedforward = 1536  # 384 * 4
+    params.encoder_dim = 384
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
+def test_model_M():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 18
+    params.dim_feedforward = 1024
+    params.encoder_dim = 256
+    params.nhead = 4
+    params.decoder_dim = 512
+    params.joiner_dim = 512
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+def main():
+    #  test_model_1()
+    test_model_M()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..5b5ac17be
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1187 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training, if you want a streaming
+        model, this requires to be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution, this requires to be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length of dynamic training, the chunk size would be either
+        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How many left context can be seen in chunks when calculating attention.",
+    )
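+
+    # A streaming-training sketch using the flags above (values are
+    # illustrative, not tuned recommendations):
+    #   ./pruned_transducer_stateless5/train.py \
+    #     --dynamic-chunk-training 1 \
+    #     --causal-convolution 1 \
+    #     --short-chunk-size 25 \
+    #     --num-left-chunks 4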
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
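+    # Illustrative arithmetic for the formula above: with average_period=100
+    # and batch_idx_train=400, the update is
+    #   model_avg = model * (100 / 400) + model_avg * (300 / 400)
+    # i.e., the current model contributes a weight of 0.25.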
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to write statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.causal_convolution,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute RNN-T loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Transducer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup:
+        A floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
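+        # reduction="none" keeps per-utterance losses so that non-finite
+        # entries can be filtered out below. The delay penalty is only
+        # applied once the model is fully warmed up (warmup >= 2.0).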
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If all simple_loss entries or all pruned_loss entries are
+            # inf or nan, we stop the training process by raising an
+            # exception.
+            if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats: tot_loss is a decayed running sum of loss_info,
+            # with decay factor (1 - 1 / reset_interval).
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    train_cuts = xbmu_amdo.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
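+        # e.g., num_frames=1000 gives T = ((1000 - 1) // 2 - 1) // 2 = 249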
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = xbmu_amdo.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = xbmu_amdo.valid_cuts()
+    valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts)
+
+    if params.start_batch <= 0 and not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 120000
index 000000000..c473a600a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless5/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..ace792e13
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,843 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
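+    # For example, with --epoch 30 --avg 9, the averaged model covers the
+    # range from epoch 21 (excluded) to epoch 30, i.e., epochs 22-30.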
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
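+            # A trivial graph: a single decoding state with self-loops for
+            # token ids 1..vocab_size-1, i.e. it accepts any token sequence.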
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    test_cuts = xbmu_amdo.test_cuts()
+
+    test_dl = xbmu_amdo.test_dataloaders(test_cuts)
+
+    test_sets = [
+        "test",
+    ]
+    test_dls = [
+        test_dl,
+    ]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
new file mode 120000
index 000000000..2713792e6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/export.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
new file mode 120000
index 000000000..a44034e34
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
new file mode 100755
index 000000000..d05bafcfb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1,355 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
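+    # Padded frames get a value of log(1e-10), i.e. a very small log-Mel
+    # energy, so they behave like near-silence.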
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..1332bafd8
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1224 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
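+        # Parse a comma-separated string, e.g. "2,4,3,2,4" -> (2, 4, 3, 2, 4).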
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      sp:
+        The BPE model used to encode the transcripts.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
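+        # e.g. with the default warm_step = 2000 and simple_loss_scale = 0.5,
+        # at batch_idx_train = 1000 the simple loss is scaled by 0.75 and the
+        # pruned loss by 0.55; after warm_step the scales are 0.5 and 1.0.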
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; we call its step_batch() method after every batch.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of GPUs used in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
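+            # tot_loss is an exponentially decaying running sum: with the
+            # default reset_interval = 200, each past batch is down-weighted
+            # by a factor of 199/200 per new batch.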
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
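+        # Kept in float64, presumably so the long-running average does not
+        # lose precision over many small updates (see --average-period).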
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    train_cuts = xbmu_amdo.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
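+        # e.g. for a 10 s utterance with a 10 ms frame shift, c.num_frames = 1000
+        # and T = ((1000 - 7) // 2 + 1) // 2 = 248.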
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = xbmu_amdo.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = xbmu_amdo.valid_cuts()
+    valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/shared b/egs/xbmu_amdo31/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file

From c25c8c6ad18b8a3d5de2f093947f7b2293eec35a Mon Sep 17 00:00:00 2001
From: Wei Kang 
Date: Sun, 4 Dec 2022 17:20:17 +0800
Subject: [PATCH 061/120] Add need_repeat_flag in phone based ctc graph
 compiler (#727)

* Fix is_repeat_token in icefall

* Fix phone based recipe

* Update egs/librispeech/ASR/conformer_ctc3/train.py

Co-authored-by: Fangjun Kuang 

* Fix black

Co-authored-by: Fangjun Kuang 
---
 egs/librispeech/ASR/conformer_ctc3/train.py |  1 +
 icefall/graph_compiler.py                   | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index fb3b740c1..ac489af9e 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -890,6 +890,7 @@ def run(rank, world_size, args):
         graph_compiler = CtcTrainingGraphCompiler(
             lexicon,
             device=device,
+            need_repeat_flag=params.delay_penalty > 0,
         )
         # Manually add the sos/eos ID with their default values
         # from the BPE recipe which we're adapting here.
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 0dcd777ad..d26ddbbd1 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
         lexicon: Lexicon,
         device: torch.device,
         oov: str = "",
+        need_repeat_flag: bool = False,
     ):
         """
         Args:
@@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
           oov:
             Out of vocabulary word. When a word in the transcript
             does not exist in the lexicon, it is replaced with `oov`.
+          need_repeat_flag:
+            If True, an attribute named `_is_repeat_token_` will be added to
+            ctc_topo, indicating whether each token is a repeat token in the
+            ctc graph. This attribute is needed to implement delay-penalty for
+            phone-based ctc loss. See https://github.com/k2-fsa/k2/pull/1086
+            for more details. Note: the above k2 change MUST be included in
+            your k2 installation before enabling this flag.
         """
         L_inv = lexicon.L_inv.to(device)
         assert L_inv.requires_grad is False
@@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object):
         ctc_topo = k2.ctc_topo(max_token_id, modified=False)
 
         self.ctc_topo = ctc_topo.to(device)
+
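+        # Note (an explanatory assumption about k2.ctc_topo with
+        # modified=False): an arc whose label differs from its aux_label is a
+        # non-blank self-loop, i.e. an arc that consumes a repeated token
+        # while emitting epsilon, so the comparison below marks exactly the
+        # repeat arcs.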
+        if need_repeat_flag:
+            self.ctc_topo._is_repeat_token_ = (
+                self.ctc_topo.labels != self.ctc_topo.aux_labels
+            )
+
         self.device = device
 
     def compile(self, texts: List[str]) -> k2.Fsa:
@@ -79,10 +93,6 @@ class CtcTrainingGraphCompiler(object):
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
-        self.ctc_topo._is_repeat_token_ = (
-            self.ctc_topo.labels != self.ctc_topo.aux_labels
-        ).int()
-
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
         )

From bd7fa2253dab9f627edc914b3289fb2f6c0e5bb6 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Sun, 4 Dec 2022 20:27:45 +0800
Subject: [PATCH 062/120] Update the manifest statistics of the L subset of
 wenetspeech (#731)

---
 .../ASR/local/display_manifest_statistics.py  | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
index c41445b8d..36e4ac5c3 100644
--- a/egs/wenetspeech/ASR/local/display_manifest_statistics.py
+++ b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
@@ -33,6 +33,7 @@ def main():
     paths = [
         "./data/fbank/cuts_S.jsonl.gz",
         "./data/fbank/cuts_M.jsonl.gz",
+        "./data/fbank/cuts_L.jsonl.gz",
         "./data/fbank/cuts_DEV.jsonl.gz",
         "./data/fbank/cuts_TEST_NET.jsonl.gz",
         "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
@@ -48,6 +49,24 @@ if __name__ == "__main__":
     main()
 
 """
+Starting display the statistics for ./data/fbank/cuts_L.jsonl.gz
+
+Cuts count: 43874235
+Total duration (hours): 30217.3
+Speech duration (hours): 30217.3 (100.0%)
+***
+Duration statistics (seconds):
+mean    2.5
+std     1.7
+min     0.2
+25%     1.4
+50%     2.0
+75%     3.0
+99%     8.4
+99.5%   9.1
+99.9%   15.4
+max     405.1
+
 Starting display the statistics for ./data/fbank/cuts_S.jsonl.gz
 Duration statistics (seconds):
 mean    2.4

From be6e08f69a9384de27c28115a299d4fe64bb5de1 Mon Sep 17 00:00:00 2001
From: Cesc 
Date: Mon, 5 Dec 2022 23:35:10 +0800
Subject: [PATCH 063/120] fix wenet stateless5 jit export error (#735)

---
 egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py      | 2 ++
 egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py       | 1 +
 .../ASR/pruned_transducer_stateless5/scaling_converter.py       | 1 +
 3 files changed, 4 insertions(+)
 mode change 100644 => 100755 egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
 create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
 create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
old mode 100644
new mode 100755
index 35577c327..cb541070e
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -74,6 +74,7 @@ import logging
 from pathlib import Path
 
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -184,6 +185,7 @@ def main():
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptable.
+        convert_scaled_to_non_scaled(model, inplace=True)
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..d13a1e063
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..e58473a04
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
\ No newline at end of file

From f13cf61b05432a989e6a42c95b843a56639bcbde Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Tue, 6 Dec 2022 16:34:27 +0800
Subject: [PATCH 064/120] Convert conv-emformer to ncnn (#717)

* Export conv-emformer via torch.jit.trace()
---
 ...former-transducer-stateless2-2022-12-05.sh |   79 +
 ...-lstm-transducer-stateless2-2022-09-03.sh} |    0
 ...ormer-transducer-stateless2-2022-12-05.yml |   77 +
 ...-lstm-transducer-stateless2-2022-09-03.yml |    2 +-
 .../emformer2.py                              | 1798 +++++++++++++++++
 .../export-for-ncnn.py                        |  335 +++
 .../jit_pretrained.py                         |  292 +++
 .../lstmp.py                                  |    1 +
 .../scaling_converter.py                      |    1 +
 .../streaming-ncnn-decode.py                  |  387 ++++
 .../train2.py                                 | 1128 +++++++++++
 11 files changed, 4099 insertions(+), 1 deletion(-)
 create mode 100755 .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
 rename .github/scripts/{run-librispeech-lstm-transducer-stateless2-2022-09-03.yml => run-librispeech-lstm-transducer-stateless2-2022-09-03.sh} (100%)
 create mode 100644 .github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
 create mode 100644 egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
 create mode 120000 egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
 create mode 120000 egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py

diff --git a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
new file mode 100755
index 000000000..32c939206
--- /dev/null
+++ b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
@@ -0,0 +1,79 @@
+#!/usr/bin/env bash
+#
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+pushd $repo
+git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+cd exp
+ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
+popd
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+log  "Install ncnn and pnnx"
+
+# We are using a modified ncnn here. Will try to merge it to the official repo
+# of ncnn
+git clone https://github.com/csukuangfj/ncnn
+pushd ncnn
+git submodule init
+git submodule update python/pybind11
+python3 setup.py bdist_wheel
+ls -lh dist/
+pip install dist/*.whl
+cd tools/pnnx
+mkdir build
+cd build
+cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 ..
+make -j4 pnnx
+
+./src/pnnx || echo "pass"
+
+popd
+
+log "Test exporting to pnnx format"
+
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir $repo/exp \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/1089-134686-0001.wav
diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
similarity index 100%
rename from .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
rename to .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
diff --git a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml b/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
new file mode 100644
index 000000000..b9a1582c4
--- /dev/null
+++ b/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
@@ -0,0 +1,77 @@
+name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_conv_emformer_transducer_stateless2_2022_12_05:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | grep -v kaldifst | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 59f116fde..f5ee09e16 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -111,7 +111,7 @@ jobs:
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
 
-          .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+          .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
 
       - name: Display decoding results for lstm_transducer_stateless2
         if: github.event_name == 'schedule'
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
new file mode 100644
index 000000000..65a7efa77
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
@@ -0,0 +1,1798 @@
+# Copyright      2022  Xiaomi Corporation     (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# It is modified based on
+# 1) https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py  # noqa
+# 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py  # noqa
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
+
+from icefall.utils import make_pad_mask
+
+LOG_EPSILON = math.log(1e-10)
+
+
+def unstack_states(
+    states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
+) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A tuple of 2 elements.
+        ``states[0]`` is the attention caches of a batch of utterances.
+        ``states[1]`` is the convolution caches of a batch of utterances.
+        ``len(states[0])`` and ``len(states[1])`` both equal the number of layers.  # noqa
+
+    Returns:
+      A list of states.
+      ``states[i]`` is a tuple of 2 elements for the i-th utterance.
+      ``states[i][0]`` is the attention caches of the i-th utterance.
+      ``states[i][1]`` is the convolution caches of the i-th utterance.
+      ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa
+    """
+
+    attn_caches, conv_caches = states
+    batch_size = conv_caches[0].size(0)
+    num_layers = len(attn_caches)
+
+    list_attn_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_attn_caches[i] = [[] for _ in range(num_layers)]
+    for li, layer in enumerate(attn_caches):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            for bi, b in enumerate(list_attn_caches):
+                b[li].append(s_list[bi])
+
+    list_conv_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_conv_caches[i] = [None] * num_layers
+    for li, layer in enumerate(conv_caches):
+        c_list = layer.unbind(dim=0)
+        for bi, b in enumerate(list_conv_caches):
+            b[li] = c_list[bi]
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [list_attn_caches[i], list_conv_caches[i]]
+
+    return ans
+
+
+def stack_states(
+    state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
+) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
+    """Stack list of emformer states that correspond to separate utterances
+    into a single emformer state so that it can be used as an input for
+    emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the emformer model for a single utterance.
+        ``states[i]`` is a tuple of 2 elements for the i-th utterance.
+        ``states[i][0]`` is the attention caches of the i-th utterance.
+        ``states[i][1]`` is the convolution caches of the i-th utterance.
+        ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers.  # noqa
+
+    Returns:
+      A new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+    """
+    batch_size = len(state_list)
+
+    attn_caches = []
+    for layer in state_list[0][0]:
+        if batch_size > 1:
+            # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s]  # noqa
+            attn_caches.append([[s] for s in layer])
+        else:
+            attn_caches.append([s.unsqueeze(1) for s in layer])
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[0]):
+            for si, s in enumerate(layer):
+                attn_caches[li][si].append(s)
+                if b == batch_size - 1:
+                    attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1)
+
+    conv_caches = []
+    for layer in state_list[0][1]:
+        if batch_size > 1:
+            # Note: We will stack conv_caches[layer][] later to get conv_caches[layer]  # noqa
+            conv_caches.append([layer])
+        else:
+            conv_caches.append(layer.unsqueeze(0))
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[1]):
+            conv_caches[li].append(layer)
+            if b == batch_size - 1:
+                conv_caches[li] = torch.stack(conv_caches[li], dim=0)
+
+    return [attn_caches, conv_caches]
+
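+# A minimal round-trip sketch (shapes are whatever the model produced): if
+# `batched` is the state of a batch of utterances, then
+#     stack_states(unstack_states(batched))
+# rebuilds `batched` element-wise for both the attention and the convolution
+# caches, since stack_states() is the inverse of unstack_states().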
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule.
+
+    Modified from https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa
+
+    Args:
+      chunk_length (int):
+        Length of each chunk.
+      right_context_length (int):
+        Length of right context.
+      channels (int):
+        The number of input channels and output channels of conv layers.
+      kernel_size (int):
+        Kernel size of conv layers.
+      bias (bool):
+        Whether to use bias in conv layers (default=True).
+    """
+
+    def __init__(
+        self,
+        chunk_length: int,
+        right_context_length: int,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+    ) -> None:
+        """Construct an ConvolutionModule object."""
+        super().__init__()
+        # kernerl_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0, kernel_size
+
+        self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
+        self.channels = channels
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        # After pointwise_conv1 we put x through a gated linear unit
+        # (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in
+        # the range 1 to 4, but sometimes, for some reason, for layer 0 the rms
+        # ends up being very large, between 50 and 100 for different channels.
+        # This will cause very peaky and sparse derivatives for the sigmoid
+        # gating function, which will tend to make the loss function not learn
+        # effectively.  (for most layers the average absolute values are in the
+        # range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1 is around 0.35 to 0.45 for
+        # different layers, which likely breaks down as 0.5 for the "linear"
+        # half and 0.2 to 0.3 for the part that goes into the sigmoid.)
+        # The idea is that if we constrain the rms values to a reasonable range
+        # via a constraint of max_abs=10.0, it will be in a better position to
+        # start learning something, i.e. to latch onto the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        # make it causal by padding cached (kernel_size - 1) frames on the left
+        self.cache_size = kernel_size - 1
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def _split_right_context(
+        self,
+        pad_utterance: torch.Tensor,
+        right_context: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          pad_utterance:
+            Its shape is (cache_size + U, B, D).
+          right_context:
+            Its shape is (R, B, D).
+
+        Returns:
+          Right context segments padded with the corresponding cached context.
+          Its shape is (num_segs * B, D, cache_size + right_context_length).
+        """
+        U_, B, D = pad_utterance.size()
+        R = right_context.size(0)
+        assert self.right_context_length != 0
+        assert R % self.right_context_length == 0
+        num_chunks = R // self.right_context_length
+        right_context = right_context.reshape(
+            num_chunks, self.right_context_length, B, D
+        )
+        right_context = right_context.permute(0, 2, 1, 3).reshape(
+            num_chunks * B, self.right_context_length, D
+        )
+
+        intervals = torch.arange(
+            0, self.chunk_length * (num_chunks - 1), self.chunk_length
+        )
+        first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size)
+        indexes = intervals.unsqueeze(1) + first.unsqueeze(0)
+        indexes = torch.cat(
+            [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)]
+        )
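+        # Worked example (hypothetical sizes): with chunk_length=32,
+        # cache_size=30 (kernel_size=31) and num_chunks=3 (so U_=126),
+        # intervals=[0, 32], first=[32..61], and the resulting index rows are
+        # [32..61], [64..93], [96..125]: for each chunk, the last cache_size
+        # frames that precede its right context segment.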
+        padding = pad_utterance[indexes]  # (num_chunks, cache_size, B, D)
+        padding = padding.permute(0, 2, 1, 3).reshape(
+            num_chunks * B, self.cache_size, D
+        )
+
+        pad_right_context = torch.cat([padding, right_context], dim=1)
+        # (num_chunks * B, cache_size + right_context_length, D)
+        return pad_right_context.permute(0, 2, 1)
+
+    def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor:
+        """
+        Args:
+          right_context:
+            Right context segments.
+            Its shape is (num_segs * B, D, right_context_length).
+          B:
+            Batch size.
+
+        Returns:
+          A tensor of shape (B, D, R), where
+          R = num_segs * right_context_length.
+        """
+        right_context = right_context.reshape(
+            -1, B, self.channels, self.right_context_length
+        )
+        right_context = right_context.permute(1, 2, 0, 3)
+        right_context = right_context.reshape(B, self.channels, -1)
+        return right_context
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Causal convolution module.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance tensor of shape (U, B, D).
+          right_context (torch.Tensor):
+            Right context tensor of shape (R, B, D).
+
+        Returns:
+          A tuple of 2 tensors:
+          - output utterance of shape (U, B, D).
+          - output right_context of shape (R, B, D).
+        """
+        U, B, D = utterance.size()
+        R, _, _ = right_context.size()
+
+        # point-wise conv and GLU mechanism
+        x = torch.cat([right_context, utterance], dim=0)  # (R + U, B, D)
+        x = x.permute(1, 2, 0)  # (B, D, R + U)
+        x = self.pointwise_conv1(x)  # (B, 2 * D, R + U)
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (B, D, R + U)
+        utterance = x[:, :, R:]  # (B, D, U)
+        right_context = x[:, :, :R]  # (B, D, R)
+
+        # make causal convolution
+        cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype)
+        pad_utterance = torch.cat([cache, utterance], dim=2)  # (B, D, cache + U)
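+        # Left-padding with (kernel_size - 1) zeros means every output frame
+        # of the depth-wise conv below depends only on current and past input
+        # frames, i.e. the convolution is causal.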
+
+        # depth-wise conv on utterance
+        utterance = self.depthwise_conv(pad_utterance)  # (B, D, U)
+
+        if self.right_context_length > 0:
+            # depth-wise conv on right_context
+            pad_right_context = self._split_right_context(
+                pad_utterance.permute(2, 0, 1), right_context.permute(2, 0, 1)
+            )  # (num_segs * B, D, cache_size + right_context_length)
+            right_context = self.depthwise_conv(
+                pad_right_context
+            )  # (num_segs * B, D, right_context_length)
+            right_context = self._merge_right_context(right_context, B)  # (B, D, R)
+
+        x = torch.cat([right_context, utterance], dim=2)  # (B, D, R + U)
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        # point-wise conv
+        x = self.pointwise_conv2(x)  # (B, D, R + U)
+
+        right_context = x[:, :, :R]  # (B, D, R)
+        utterance = x[:, :, R:]  # (B, D, U)
+        return (
+            utterance.permute(2, 0, 1),
+            right_context.permute(2, 0, 1),
+        )
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        cache: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Causal convolution module applied on both utterance and right_context.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance tensor of shape (U, B, D).
+          right_context (torch.Tensor):
+            Right context tensor of shape (R, B, D).
+          cache (torch.Tensor, optional):
+            Cached tensor for left padding of shape (B, D, cache_size).
+
+        Returns:
+          A tuple of 3 tensors:
+            - output utterance of shape (U, B, D).
+            - output right_context of shape (R, B, D).
+            - updated cache tensor of shape (B, D, cache_size).
+        """
+        #  U, B, D = utterance.size()
+        #  R, _, _ = right_context.size()
+        U = self.chunk_length
+        B = 1
+        D = self.channels
+        R = self.right_context_length
+
+        # point-wise conv
+        x = torch.cat([utterance, right_context], dim=0)  # (U + R, B, D)
+        x = x.permute(1, 2, 0)  # (B, D, U + R)
+        x = self.pointwise_conv1(x)  # (B, 2 * D, U + R)
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (B, D, U + R)
+
+        # make causal convolution
+        assert cache.shape == (B, D, self.cache_size), cache.shape
+        x = torch.cat([cache, x], dim=2)  # (B, D, cache_size + U + R)
+        # update cache
+        new_cache = x[:, :, -R - self.cache_size : -R]
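+        # i.e. keep the last cache_size frames of the utterance part (the
+        # frames immediately before the right context) as the left padding
+        # for the next chunk.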
+
+        # 1-D depth-wise conv
+        x = self.depthwise_conv(x)  # (B, D, U + R)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        # point-wise conv
+        x = self.pointwise_conv2(x)  # (B, D, U + R)
+
+        utterance = x[:, :, :U]  # (B, D, U)
+        right_context = x[:, :, U:]  # (B, D, R)
+        return (
+            utterance.permute(2, 0, 1),
+            right_context.permute(2, 0, 1),
+            new_cache,
+        )
+
+
+class EmformerAttention(nn.Module):
+    r"""Emformer layer attention module.
+
+    Args:
+      embed_dim (int):
+        Embedding dimension.
+      nhead (int):
+        Number of attention heads in each Emformer layer.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights. (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        nhead: int,
+        left_context_length: int,
+        chunk_length: int,
+        right_context_length: int,
+        memory_size: int,
+        dropout: float = 0.0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        if embed_dim % nhead != 0:
+            raise ValueError(
+                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
+            )
+
+        self.embed_dim = embed_dim
+        self.nhead = nhead
+        self.tanh_on_mem = tanh_on_mem
+        self.negative_inf = negative_inf
+        self.head_dim = embed_dim // nhead
+        self.dropout = dropout
+
+        self.left_context_length = left_context_length
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        self.memory_size = memory_size
+
+        self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True)
+        self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+    def _gen_attention_probs(
+        self,
+        attention_weights: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Given the entire attention weights, mask out unecessary connections
+        and optionally with padding positions, to obtain underlying chunk-wise
+        attention probabilities.
+
+        B: batch size;
+        Q: length of query;
+        KV: length of key and value.
+
+        Args:
+          attention_weights (torch.Tensor):
+            Attention weights computed on the entire concatenated tensor
+            with shape (B * nhead, Q, KV).
+          attention_mask (torch.Tensor):
+            Mask tensor where chunk-wise connections are filled with `False`,
+            and other unnecessary connections are filled with `True`,
+            with shape (Q, KV).
+          padding_mask (torch.Tensor, optional):
+            Mask tensor where the padding positions are filled with `True`,
+            and other positions are filled with `False`, with shape `(B, KV)`.
+
+        Returns:
+          A tensor of shape (B * nhead, Q, KV).
+        """
+        attention_weights_float = attention_weights.float()
+        attention_weights_float = attention_weights_float.masked_fill(
+            attention_mask.unsqueeze(0), self.negative_inf
+        )
+        if padding_mask is not None:
+            Q = attention_weights.size(1)
+            B = attention_weights.size(0) // self.nhead
+            attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1)
+            attention_weights_float = attention_weights_float.masked_fill(
+                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                self.negative_inf,
+            )
+            attention_weights_float = attention_weights_float.view(
+                B * self.nhead, Q, -1
+            )
+
+        attention_probs = nn.functional.softmax(
+            attention_weights_float, dim=-1
+        ).type_as(attention_weights)
+
+        attention_probs = nn.functional.dropout(
+            attention_probs, p=self.dropout, training=self.training
+        )
+        return attention_probs
+
+    def _forward_impl(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        left_context_key: Optional[torch.Tensor] = None,
+        left_context_val: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Underlying chunk-wise attention implementation."""
+        #  U, B, _ = utterance.size()
+        #  R = right_context.size(0)
+        #  M = memory.size(0)
+
+        U = self.chunk_length
+        B = 1
+        R = self.right_context_length
+        M = self.memory_size
+        L = self.left_context_length
+
+        scaling = float(self.head_dim) ** -0.5
+
+        # compute query with [right_context, utterance].
+        query = self.emb_to_query(torch.cat([right_context, utterance]))
+        # compute key and value with [memory, right_context, utterance].
+        key, value = self.emb_to_key_value(
+            torch.cat([memory, right_context, utterance])
+        ).chunk(chunks=2, dim=2)
+
+        if left_context_key is not None and left_context_val is not None:
+            # now compute key and value with
+            #   [memory, right context, left context, utterance]
+            # this is used in inference mode
+            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
+            value = torch.cat([value[: M + R], left_context_val, value[M + R :]])
+
+        #  Q = query.size(0)
+        Q = U + R
+
+        # KV = key.size(0)
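+        # For example, with the export configuration used in the CI script in
+        # this patch (chunk_length=32, right_context_length=8,
+        # left_context_length=32, memory_size=32), Q = 40 and the key/value
+        # length in this streaming variant is M + R + U + L = 104.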
+
+        reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2)
+        reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute(
+            1, 0, 2
+        )
+        reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute(
+            1, 0, 2
+        )
+
+        #  reshaped_query, reshaped_key, reshaped_value = [
+        #      tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1)
+        #      for tensor in [query, key, value]
+        #  ]  # (B * nhead, Q or KV, head_dim)
+        attention_weights = torch.bmm(
+            reshaped_query * scaling, reshaped_key.permute(0, 2, 1)
+        )  # (B * nhead, Q, KV)
+
+        # compute attention probabilities
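+        # _gen_attention_probs() is bypassed below and a plain softmax is
+        # applied instead, presumably because in this streaming/ncnn-export
+        # variant the mask built in infer() is all False and there is no
+        # padding (batch size is fixed to 1).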
+        if False:
+            attention_probs = self._gen_attention_probs(
+                attention_weights, attention_mask, padding_mask
+            )
+        else:
+            attention_probs = nn.functional.softmax(attention_weights, dim=-1)
+
+        # compute attention outputs
+        attention = torch.bmm(attention_probs, reshaped_value)
+        assert attention.shape == (B * self.nhead, Q, self.head_dim)
+        attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim)
+        # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim)
+        # We have to change InnerProduct in ncnn to ignore the extra dim below
+        attention = attention.unsqueeze(1)
+
+        # apply output projection
+        output_right_context_utterance = self.out_proj(attention)
+        # The return shape of output_right_context_utterance is (10, 1, 512)
+
+        return output_right_context_utterance, key, value
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # TODO: Modify docs.
+        """Forward pass for training and validation mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of the hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors.
+
+        It computes a `big` attention matrix on full utterance and
+        then utilizes a pre-computed mask to simulate chunk-wise attention.
+
+        It concatenates two blocks: the hard-copied right contexts
+        and the full utterance, as a `big` block,
+        to compute the query tensor:
+        query = [right_context, utterance],
+        with length Q = R + U.
+        It concatenates the three blocks: memory vectors,
+        hard-copied right contexts, and full utterance as another `big` block,
+        to compute the key and value tensors:
+        key & value = [memory, right_context, utterance],
+        with length KV = M + R + U.
+        Attention scores are computed with the above `big` query and key.
+
+        Then the underlying chunk-wise attention is obtained by applying
+        the attention mask. Suppose
+        c_i: chunk at index i;
+        r_i: right context that c_i can use;
+        l_i: left context that c_i can use;
+        m_i: past memory vectors from previous layer that c_i can use;
+        The target chunk-wise attention is:
+        c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key)
+
+        Args:
+          utterance (torch.Tensor):
+            Full utterance frames, with shape (U, B, D).
+          right_context (torch.Tensor):
+            Hard-copied right context frames, with shape (R, B, D),
+            where R = num_chunks * right_context_length
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D), where M = num_chunks - 1.
+            It is an empty tensor when memory is not used.
+          attention_mask (torch.Tensor):
+            Pre-computed attention mask to simulate underlying chunk-wise
+            attention, with shape (Q, KV).
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
+
+        Returns:
+          Output of right context and utterance, with shape (R + U, B, D).
+        """
+        output_right_context_utterance, _, _ = self._forward_impl(
+            utterance,
+            right_context,
+            memory,
+            attention_mask,
+            padding_mask=padding_mask,
+        )
+        return output_right_context_utterance
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        left_context_key: torch.Tensor,
+        left_context_val: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass for inference.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of right context;
+        U: length of utterance, i.e., current chunk;
+        L: length of cached left context;
+        M: length of cached memory vectors.
+
+        It concatenates the right context and the utterance (i.e., the current
+        chunk) to compute the query tensor:
+        query = [right_context, utterance],
+        with length Q = R + U.
+        It concatenates the memory vectors, right context, left context, and
+        current chunk, to compute the key and value tensors:
+        key & value = [memory, right_context, left_context, utterance],
+        with length KV = M + R + L + U.
+
+        The chunk-wise attention is:
+        chunk, right context (in query) ->
+          left context, chunk, right context, memory vectors (in key).
+
+        Args:
+          utterance (torch.Tensor):
+            Current chunk frames, with shape (U, B, D), where U = chunk_length.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D),
+            where R = right_context_length.
+          memory (torch.Tensor):
+            Memory vectors, with shape (M, B, D), or empty tensor.
+          left_context_key (torch.Tensor):
+            Cached attention key of left context from preceding computation,
+            with shape (L, B, D).
+          left_context_val (torch.Tensor):
+            Cached attention value of left context from preceding computation,
+            with shape (L, B, D).
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
+
+        Returns:
+          A tuple containing 3 tensors:
+            - output of right context and utterance, with shape (R + U, B, D).
+            - attention key of left context and utterance, which would be cached
+              for next computation, with shape (L + U, B, D).
+            - attention value of left context and utterance, which would be
+              cached for next computation, with shape (L + U, B, D).
+        """
+        #  U = utterance.size(0)
+        #  R = right_context.size(0)
+        #  L = left_context_key.size(0)
+        #  M = memory.size(0)
+
+        U = self.chunk_length
+        R = self.right_context_length
+        L = self.left_context_length
+        M = self.memory_size
+
+        # query = [right context, utterance]
+        Q = R + U
+        # key, value = [memory, right context, left context, utterance]
+        KV = M + R + L + U
+        attention_mask = torch.zeros(Q, KV).to(
+            dtype=torch.bool, device=utterance.device
+        )
+
+        output_right_context_utterance, key, value = self._forward_impl(
+            utterance,
+            right_context,
+            memory,
+            attention_mask,
+            padding_mask=padding_mask,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
+        return (
+            output_right_context_utterance,
+            key[M + R :],
+            value[M + R :],
+        )
+
+
+class EmformerEncoderLayer(nn.Module):
+    """Emformer layer that constitutes Emformer.
+
+    Args:
+      d_model (int):
+        Input dimension.
+      nhead (int):
+        Number of attention heads.
+      dim_feedforward (int):
+        Hidden layer dimension of feedforward network.
+      chunk_length (int):
+        Length of each input segment.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.0)
+      layer_dropout (float, optional):
+        Layer dropout probability. (Default: 0.0)
+      cnn_module_kernel (int):
+        Kernel size of convolution module.
+      left_context_length (int, optional):
+        Length of left context. (Default: 0)
+      right_context_length (int, optional):
+        Length of right context. (Default: 0)
+      memory_size (int, optional):
+        Number of memory elements to use. (Default: 0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights. (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int,
+        chunk_length: int,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        self.attention = EmformerAttention(
+            embed_dim=d_model,
+            nhead=nhead,
+            left_context_length=left_context_length,
+            chunk_length=chunk_length,
+            memory_size=memory_size,
+            right_context_length=right_context_length,
+            dropout=dropout,
+            tanh_on_mem=tanh_on_mem,
+            negative_inf=negative_inf,
+        )
+        self.summary_op = nn.AvgPool1d(
+            kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(
+            chunk_length,
+            right_context_length,
+            d_model,
+            cnn_module_kernel,
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean
+        # (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.layer_dropout = layer_dropout
+        self.left_context_length = left_context_length
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        self.memory_size = memory_size
+        self.d_model = d_model
+        self.use_memory = memory_size > 0
+
+    def _update_attn_cache(
+        self,
+        next_key: torch.Tensor,
+        next_val: torch.Tensor,
+        memory: torch.Tensor,
+        attn_cache: List[torch.Tensor],
+    ) -> List[torch.Tensor]:
+        """Update cached attention state:
+        1) output memory of current chunk in the lower layer;
+        2) attention key and value in current chunk's computation, which would
+        be reused in next chunk's computation.
+        """
+        # attn_cache[0].shape (self.memory_size, 1, 512)
+        # memory.shape (1, 1, 512)
+        # attn_cache[1].shape (self.left_context_length, 1, 512)
+        # attn_cache[2].shape (self.left_context_length, 1, 512)
+        # next_key.shape (self.left_context_length + self.chunk_length, 1, 512)
+        # next_val.shape (self.left_context_length + self.chunk_length, 1, 512)
+        new_memory = torch.cat([attn_cache[0], memory])
+        # TODO(fangjun): Remove torch.cat
+        #  new_key = torch.cat([attn_cache[1], next_key])
+        #  new_val = torch.cat([attn_cache[2], next_val])
+        attn_cache[0] = new_memory[1:]
+        attn_cache[1] = next_key[-self.left_context_length :]
+        attn_cache[2] = next_val[-self.left_context_length :]
+        return attn_cache
+
+    def _apply_conv_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+    ) -> torch.Tensor:
+        """Apply convolution module in training and validation mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+        utterance, right_context = self.conv_module(utterance, right_context)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance
+
+    def _apply_conv_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        conv_cache: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply convolution module on utterance in inference mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+        utterance, right_context, conv_cache = self.conv_module.infer(
+            utterance, right_context, conv_cache
+        )
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance, conv_cache
+
+    def _apply_attention_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Apply attention module in training and validation mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        if self.use_memory:
+            memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1, :, :
+            ]
+        else:
+            memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+        output_right_context_utterance = self.attention(
+            utterance=utterance,
+            right_context=right_context,
+            memory=memory,
+            attention_mask=attention_mask,
+            padding_mask=padding_mask,
+        )
+
+        return output_right_context_utterance
+
+    def _apply_attention_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        attn_cache: List[torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Apply attention module in inference mode.
+        1) Unpack cached states including:
+           - memory from previous chunks;
+           - attention key and value of left context from preceding
+             chunk's computation;
+        2) Apply attention computation;
+        3) Update cached attention states including:
+           - memory of current chunk;
+           - attention key and value in current chunk's computation, which would
+             be reused in next chunk's computation.
+        """
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        pre_memory = attn_cache[0]
+        left_context_key = attn_cache[1]
+        left_context_val = attn_cache[2]
+
+        if self.use_memory:
+            memory = torch.mean(utterance, dim=0, keepdim=True)
+
+            #  memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+            #          :1, :, :
+            #  ]
+        else:
+            memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+        (output_right_context_utterance, next_key, next_val) = self.attention.infer(
+            utterance=utterance,
+            right_context=right_context,
+            memory=pre_memory,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+            padding_mask=padding_mask,
+        )
+        attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache)
+        return output_right_context_utterance, attn_cache
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Forward pass for training and validation mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance frames, with shape (U, B, D).
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D).
+          attention_mask (torch.Tensor):
+            Attention mask for underlying attention module,
+            with shape (Q, KV), where Q = R + U, KV = M + R + U.
+          padding_mask (torch.Tensor):
+            Padding mask of the key tensor, with shape (B, KV).
+
+        Returns:
+          A tuple containing 2 tensors:
+            - output utterance, with shape (U, B, D).
+            - output right context, with shape (R, B, D).
+        """
+        R = right_context.size(0)
+        src = torch.cat([right_context, utterance])
+        src_orig = src
+
+        warmup_scale = min(0.1 + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
+        else:
+            alpha = 1.0
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        src_att = self._apply_attention_module_forward(
+            src, R, attention_mask, padding_mask=padding_mask
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv = self._apply_conv_module_forward(src, R)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return output_utterance, output_right_context
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        cache: List[torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Forward pass for inference.
+
+         B: batch size;
+         D: embedding dimension;
+         R: length of right_context;
+         U: length of utterance;
+         M: length of memory.
+
+        Args:
+           utterance (torch.Tensor):
+             Utterance frames, with shape (U, B, D).
+           right_context (torch.Tensor):
+             Right context frames, with shape (R, B, D).
+           cache (List[torch.Tensor]):
+             Cached states of this layer generated in the preceding chunk's
+             computation, containing 3 attention cache tensors (memory, and
+             key and value of left context) followed by 1 cache tensor of
+             left context for causal convolution.
+           padding_mask (torch.Tensor):
+             Padding mask of the key tensor.
+
+         Returns:
+           (Tensor, Tensor, List[torch.Tensor]):
+             - output utterance, with shape (U, B, D);
+             - output right_context, with shape (R, B, D);
+             - updated cache (attention caches followed by the convolution cache).
+        """
+        R = self.right_context_length
+        src = torch.cat([right_context, utterance])
+        attn_cache = cache[:3]
+        conv_cache = cache[3]
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        src_att, attn_cache = self._apply_attention_module_infer(
+            src, R, attn_cache, padding_mask=padding_mask
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return (output_utterance, output_right_context, attn_cache + [conv_cache])
+
+
+def _gen_attention_mask_block(
+    col_widths: List[int],
+    col_mask: List[bool],
+    num_rows: int,
+    device: torch.device,
+) -> torch.Tensor:
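+    """Concatenate ones/zeros column blocks into one row block of the
+    attention mask.
+
+    Illustrative example (a sketch, not taken from the recipe):
+      _gen_attention_mask_block([1, 2], [True, False], num_rows=2, device)
+      returns a (2, 3) tensor whose first column is ones and whose remaining
+      two columns are zeros.
+    """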
+    assert len(col_widths) == len(
+        col_mask
+    ), "Length of col_widths must match that of col_mask"
+
+    mask_block = [
+        torch.ones(num_rows, col_width, device=device)
+        if is_ones_col
+        else torch.zeros(num_rows, col_width, device=device)
+        for col_width, is_ones_col in zip(col_widths, col_mask)
+    ]
+    return torch.cat(mask_block, dim=1)
+
+
+class EmformerEncoder(nn.Module):
+    """Implements the Emformer architecture introduced in
+    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency
+    Streaming Speech Recognition*
+    [:footcite:`shi2021emformer`].
+
+    In this model, the memory bank computation is simplified, using the averaged
+    value of each chunk as its memory vector.
+
+    Args:
+      d_model (int):
+        Input dimension.
+      nhead (int):
+        Number of attention heads in each emformer layer.
+      dim_feedforward (int):
+        Hidden layer dimension of each emformer layer's feedforward network.
+      num_encoder_layers (int):
+        Number of emformer layers to instantiate.
+      chunk_length (int):
+        Length of each input segment.
+      dropout (float, optional):
+        Dropout probability. (default: 0.0)
+      layer_dropout (float, optional):
+        Layer dropout probability. (default: 0.0)
+      cnn_module_kernel (int):
+        Kernel size of convolution module.
+      left_context_length (int, optional):
+        Length of left context. (default: 0)
+      right_context_length (int, optional):
+        Length of right context. (default: 0)
+      memory_size (int, optional):
+        Number of memory elements to use. (default: 0)
+      tanh_on_mem (bool, optional):
+        If ``true``, applies tanh to memory elements. (default: ``false``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights. (default: -1e8)
+    """
+
+    def __init__(
+        self,
+        chunk_length: int,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        assert (
+            chunk_length - 1
+        ) & chunk_length == 0, "chunk_length should be a power of 2."
+        self.shift = int(math.log(chunk_length, 2))
+
+        self.use_memory = memory_size > 0
+
+        self.emformer_layers = nn.ModuleList(
+            [
+                EmformerEncoderLayer(
+                    d_model=d_model,
+                    nhead=nhead,
+                    dim_feedforward=dim_feedforward,
+                    chunk_length=chunk_length,
+                    dropout=dropout,
+                    layer_dropout=layer_dropout,
+                    cnn_module_kernel=cnn_module_kernel,
+                    left_context_length=left_context_length,
+                    right_context_length=right_context_length,
+                    memory_size=memory_size,
+                    tanh_on_mem=tanh_on_mem,
+                    negative_inf=negative_inf,
+                )
+                for layer_idx in range(num_encoder_layers)
+            ]
+        )
+
+        self.num_encoder_layers = num_encoder_layers
+        self.d_model = d_model
+        self.left_context_length = left_context_length
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        self.memory_size = memory_size
+        self.cnn_module_kernel = cnn_module_kernel
+
+    def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
+        """Hard copy each chunk's right context and concat them."""
+        T = x.shape[0]
+        num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length)
+        # first (num_chunks - 1) right context block
+        intervals = torch.arange(
+            0, self.chunk_length * (num_chunks - 1), self.chunk_length
+        )
+        first = torch.arange(
+            self.chunk_length, self.chunk_length + self.right_context_length
+        )
+        indexes = intervals.unsqueeze(1) + first.unsqueeze(0)
+        # cat last right context block
+        indexes = torch.cat(
+            [
+                indexes,
+                torch.arange(T - self.right_context_length, T).unsqueeze(0),
+            ]
+        )
+        right_context_blocks = x[indexes.reshape(-1)]
+        return right_context_blocks
+
+    def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]:
+        """Calculate column widths (key, value) in attention mask for the
+        chunk_idx chunk."""
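+        # Illustrative example (a sketch, not from the recipe): with
+        # chunk_length=4, left_context_length=4, right_context_length=2,
+        # memory_size=2 and U=8 (so num_chunks=2, R=4, M=1), chunk_idx=0
+        # gives col_widths = [0, 0, 1, 0, 2, 2, 0, 4, 4], whose sum is
+        # M + R + U = 13, i.e. the KV length of the attention mask.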
+        num_chunks = math.ceil(U / self.chunk_length)
+        rc = self.right_context_length
+        lc = self.left_context_length
+        rc_start = chunk_idx * rc
+        rc_end = rc_start + rc
+        chunk_start = max(chunk_idx * self.chunk_length - lc, 0)
+        chunk_end = min((chunk_idx + 1) * self.chunk_length, U)
+        R = rc * num_chunks
+
+        if self.use_memory:
+            m_start = max(chunk_idx - self.memory_size, 0)
+            M = num_chunks - 1
+            col_widths = [
+                m_start,  # before memory
+                chunk_idx - m_start,  # memory
+                M - chunk_idx,  # after memory
+                rc_start,  # before right context
+                rc,  # right context
+                R - rc_end,  # after right context
+                chunk_start,  # before chunk
+                chunk_end - chunk_start,  # chunk
+                U - chunk_end,  # after chunk
+            ]
+        else:
+            col_widths = [
+                rc_start,  # before right context
+                rc,  # right context
+                R - rc_end,  # after right context
+                chunk_start,  # before chunk
+                chunk_end - chunk_start,  # chunk
+                U - chunk_end,  # after chunk
+            ]
+
+        return col_widths
+
+    def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor:
+        """Generate attention mask to simulate underlying chunk-wise attention
+        computation, where chunk-wise connections are filled with `False`,
+        and other unnecessary connections beyond chunk are filled with `True`.
+
+        R: length of hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors;
+        Q: length of attention query;
+        KV: length of attention key and value.
+
+        The shape of attention mask is (Q, KV).
+        If self.use_memory is `True`:
+          query = [right_context, utterance];
+          key, value = [memory, right_context, utterance];
+          Q = R + U, KV = M + R + U.
+        Otherwise:
+          query = [right_context, utterance]
+          key, value = [right_context, utterance]
+          Q = R + U, KV = R + U.
+
+        Suppose:
+          c_i: chunk at index i;
+          r_i: right context that c_i can use;
+          l_i: left context that c_i can use;
+          m_i: past memory vectors from previous layer that c_i can use;
+        The target chunk-wise attention is:
+          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key).
+        """
+        U = utterance.size(0)
+        num_chunks = math.ceil(U / self.chunk_length)
+
+        right_context_mask = []
+        utterance_mask = []
+
+        if self.use_memory:
+            num_cols = 9
+            # right context and utterance both attend to memory, right context,
+            # utterance
+            right_context_utterance_cols_mask = [
+                idx in [1, 4, 7] for idx in range(num_cols)
+            ]
+        else:
+            num_cols = 6
+            # right context and utterance both attend to right context and
+            # utterance
+            right_context_utterance_cols_mask = [
+                idx in [1, 4] for idx in range(num_cols)
+            ]
+        masks_to_concat = [right_context_mask, utterance_mask]
+
+        for chunk_idx in range(num_chunks):
+            col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)
+
+            right_context_mask_block = _gen_attention_mask_block(
+                col_widths,
+                right_context_utterance_cols_mask,
+                self.right_context_length,
+                utterance.device,
+            )
+            right_context_mask.append(right_context_mask_block)
+
+            utterance_mask_block = _gen_attention_mask_block(
+                col_widths,
+                right_context_utterance_cols_mask,
+                min(
+                    self.chunk_length,
+                    U - chunk_idx * self.chunk_length,
+                ),
+                utterance.device,
+            )
+            utterance_mask.append(utterance_mask_block)
+
+        attention_mask = (
+            1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
+        ).to(torch.bool)
+        return attention_mask
+
+    def _forward(
+        self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass for training and validation mode.
+
+        B: batch size;
+        D: input dimension;
+        U: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (U + right_context_length, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing number of valid
+            utterance frames for i-th batch element in x, which contains the
+            right_context at the end.
+
+        Returns:
+          A tuple of 2 tensors:
+            - output utterance frames, with shape (U, B, D).
+            - output_lengths, with shape (B,), without containing the
+              right_context at the end.
+        """
+        U = x.size(0) - self.right_context_length
+
+        right_context = self._gen_right_context(x)
+        utterance = x[:U]
+        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
+        attention_mask = self._gen_attention_mask(utterance)
+
+        M = (
+            right_context.size(0) // self.right_context_length - 1
+            if self.use_memory
+            else 0
+        )
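+        # The attention key/value sequence is [memory, right_context, utterance],
+        # so its per-sample valid length is M + right_context.size(0) + output
+        # length; only the utterance part can contain padding.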
+        padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
+
+        output = utterance
+        for layer in self.emformer_layers:
+            output, right_context = layer(
+                output,
+                right_context,
+                attention_mask,
+                padding_mask=padding_mask,
+                warmup=warmup,
+            )
+
+        return output, output_lengths
+
+    @torch.jit.export
+    def infer(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+        """Forward pass for streaming inference.
+
+        B: batch size;
+        D: input dimension;
+        U: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (U + right_context_length, B, D).
+          states (List[torch.Tensor]):
+            Cached states containing, for each emformer layer:
+            - 3 attention cache tensors (memory, and key and value of left
+              context) from the preceding chunk's computation;
+            - 1 cache tensor of left context for causal convolution.
+
+        Returns:
+          (Tensor, List[torch.Tensor]):
+            - output utterance frames, with shape (U, B, D).
+            - updated states from current chunk's computation.
+        """
+        # x.size(0) == self.chunk_length + self.right_context_length
+        utterance = x[: self.chunk_length]
+        right_context = x[self.chunk_length :]
+        #  right_context_utterance = torch.cat([right_context, utterance])
+
+        output = utterance
+        output_states: List[torch.Tensor] = []
+        for layer_idx, layer in enumerate(self.emformer_layers):
+            start = layer_idx * 4
+            end = start + 4
+            cache = states[start:end]
+
+            (output, right_context, output_cache,) = layer.infer(
+                output,
+                right_context,
+                padding_mask=None,
+                cache=cache,
+            )
+            output_states.extend(output_cache)
+
+        return output, output_states
+
+    @torch.jit.export
+    def init_states(
+        self, device: torch.device = torch.device("cpu")
+    ) -> List[torch.Tensor]:
+        """Create initial states."""
+        states = []
+        # layer0: attn cache, conv cache, 3 tensors + 1 tensor
+        # layer1: attn cache, conv cache, 3 tensors +  1 tensor
+        # layer2: attn cache, conv cache, 3 tensors + 1 tensor
+        # ...
+        # last layer: attn cache, conv cache, 3 tensors + 1 tensor
+        for i in range(self.num_encoder_layers):
+            states.append(torch.zeros(self.memory_size, 1, self.d_model, device=device))
+            states.append(
+                torch.zeros(self.left_context_length, 1, self.d_model, device=device)
+            )
+            states.append(
+                torch.zeros(self.left_context_length, 1, self.d_model, device=device)
+            )
+
+            states.append(
+                torch.zeros(1, self.d_model, self.cnn_module_kernel - 1, device=device)
+            )
+        return states
+
+
+class Emformer(EncoderInterface):
+    def __init__(
+        self,
+        num_features: int,
+        chunk_length: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 3,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+        is_pnnx: bool = True,
+    ):
+        super().__init__()
+
+        self.subsampling_factor = subsampling_factor
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+        if chunk_length % subsampling_factor != 0:
+            raise NotImplementedError(
+                "chunk_length must be a mutiple of subsampling_factor."
+            )
+        if left_context_length != 0 and left_context_length % subsampling_factor != 0:
+            raise NotImplementedError(
+                "left_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
+            )
+        if right_context_length != 0 and right_context_length % subsampling_factor != 0:
+            raise NotImplementedError(
+                "right_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
+            )
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx)
+        self.is_pnnx = is_pnnx
+
+        self.encoder = EmformerEncoder(
+            chunk_length=chunk_length // subsampling_factor,
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=num_encoder_layers,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+            cnn_module_kernel=cnn_module_kernel,
+            left_context_length=left_context_length // subsampling_factor,
+            right_context_length=right_context_length // subsampling_factor,
+            memory_size=memory_size,
+            tanh_on_mem=tanh_on_mem,
+            negative_inf=negative_inf,
+        )
+
+    def _forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass for training and non-streaming inference.
+
+        B: batch size;
+        D: feature dimension;
+        T: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (B, T, D).
+          x_lens (torch.Tensor):
+            With shape (B,) and i-th element representing number of valid
+            utterance frames for i-th batch element in x, containing the
+            right_context at the end.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up".  It is used
+            to turn modules on sequentially.
+
+        Returns:
+          (Tensor, Tensor):
+            - output embedding, with shape (B, T', D), where
+              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
+            - output lengths, with shape (B,), without containing the
+              right_context at the end.
+        """
+        x = self.encoder_embed(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        x_lens = (((x_lens - 1) >> 1) - 1) >> 1
+        assert x.size(0) == x_lens.max().item()
+
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
+
+        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return output, output_lengths
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+        """Forward pass for streaming inference.
+
+        B: batch size;
+        D: feature dimension;
+        T: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (B, T, D).
+          states (List[torch.Tensor]):
+            Cached states containing, for each emformer layer:
+            - 3 attention cache tensors (memory, and key and value of left
+              context) from the preceding chunk's computation;
+            - 1 cache tensor of left context for causal convolution.
+        Returns:
+          (Tensor, List[torch.Tensor]):
+            - output embedding, with shape (B, T', D), where
+              T' = chunk_length // 4 (the subsampled frames of the current chunk).
+            - updated states from current chunk's computation.
+        """
+        x = self.encoder_embed(x)
+        # drop the first and last frames
+        x = x[:, 1:-1, :]
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        # Caution: We assume the subsampling factor is 4!
+
+        output, output_states = self.encoder.infer(x, states)
+
+        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return output, output_states
+
+    @torch.jit.export
+    def init_states(
+        self, device: torch.device = torch.device("cpu")
+    ) -> List[torch.Tensor]:
+        """Create initial states."""
+        return self.encoder.init_states(device)
+
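+# A minimal streaming usage sketch of the Emformer class above (illustrative
+# only; the feature dimension, chunk/context sizes and padding below are
+# assumptions mirroring the export scripts, not values fixed by this file):
+#
+#   model = Emformer(num_features=80, chunk_length=32, d_model=512, nhead=8,
+#                    dim_feedforward=2048, num_encoder_layers=12,
+#                    cnn_module_kernel=31, left_context_length=32,
+#                    right_context_length=8, memory_size=32).eval()
+#   states = model.init_states()
+#   # one chunk plus right context and subsampling padding: 32 + 8 + 2 * 4 + 3
+#   x = torch.rand(1, 51, 80)
+#   with torch.no_grad():
+#       encoder_out, states = model(x, states)  # encoder_out: (1, 8, 512)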
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+        is_pnnx: bool = False,
+    ) -> None:
+        """
+        Args:
+          in_channels:
+            Number of channels in. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+          layer3_channels:
+            Number of channels in layer3
+          is_pnnx:
+            True if we are converting the model to PNNX format.
+            False otherwise.
+        """
+        assert in_channels >= 7
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain median of output to be close to zero.
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )
+
+        # ncnn supports only batch size == 1
+        self.is_pnnx = is_pnnx
+        self.conv_out_dim = self.out.weight.shape[1]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+
+        if torch.jit.is_tracing() and self.is_pnnx:
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
+        else:
+            # Now x is of shape (N, odim, ((T-1)//2-1)//2, ((idim-1)//2-1)//2)
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
new file mode 100755
index 000000000..716de5734
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+
+"""
+Usage:
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir ./conv_emformer_transducer_stateless2/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 10 \
+  --use-averaged-model=True \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+cd ./conv_emformer_transducer_stateless2/exp
+pnnx encoder_jit_trace-pnnx.pt
+pnnx decoder_jit_trace-pnnx.pt
+pnnx joiner_jit_trace-pnnx.pt
+
+You can find converted models at
+https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04
+
+See ./streaming-ncnn-decode.py
+and
+https://github.com/k2-fsa/sherpa-ncnn
+for usage.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train2 import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def export_encoder_model_jit_trace(
+    encoder_model: torch.nn.Module,
+    encoder_filename: str,
+) -> None:
+    """Export the given encoder model with torch.jit.trace()
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The input encoder model
+      encoder_filename:
+        The filename to save the exported model.
+    """
+    chunk_length = encoder_model.chunk_length  # before subsampling
+    right_context_length = encoder_model.right_context_length  # before subsampling
+    pad_length = right_context_length + 2 * 4 + 3
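+    # The 2 * 4 + 3 extra frames cover the 3 frames of context consumed by the
+    # two stride-2 convolutions in Conv2dSubsampling plus the two subsampled
+    # frames (2 * 4 input frames) dropped in the streaming forward().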
+    s = f"chunk_length: {chunk_length}, "
+    s += f"right_context_length: {right_context_length}\n"
+    logging.info(s)
+
+    T = chunk_length + pad_length
+
+    x = torch.zeros(1, T, 80, dtype=torch.float32)
+    states = encoder_model.init_states()
+
+    traced_model = torch.jit.trace(encoder_model, (x, states))
+    traced_model.save(encoder_filename)
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_jit_trace(
+    decoder_model: torch.nn.Module,
+    decoder_filename: str,
+) -> None:
+    """Export the given decoder model with torch.jit.trace()
+
+    Note: The argument need_pad is fixed to False.
+
+    Args:
+      decoder_model:
+        The input decoder model
+      decoder_filename:
+        The filename to save the exported model.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = torch.tensor([False])
+
+    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+    traced_model.save(decoder_filename)
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_jit_trace(
+    joiner_model: torch.nn.Module,
+    joiner_filename: str,
+) -> None:
+    """Export the given joiner model with torch.jit.trace()
+
+    Note: The argument project_input is fixed to True. A user should not
+    project the encoder_out/decoder_out by himself/herself. The exported joiner
+    will do that for the user.
+
+    Args:
+      joiner_model:
+        The input joiner model
+      joiner_filename:
+        The filename to save the exported model.
+
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
+    traced_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    convert_scaled_to_non_scaled(model, inplace=True)
+    logging.info("Using torch.jit.trace()")
+
+    logging.info("Exporting encoder")
+    encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
+    export_encoder_model_jit_trace(model.encoder, encoder_filename)
+
+    logging.info("Exporting decoder")
+    decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
+    export_decoder_model_jit_trace(model.decoder, decoder_filename)
+
+    logging.info("Exporting joiner")
+    joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
+    export_joiner_model_jit_trace(model.joiner, joiner_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
new file mode 100755
index 000000000..1fe358c79
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+# flake8: noqa
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang, Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models exported by `torch.jit.trace()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir ./conv_emformer_transducer_stateless2/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+./conv_emformer_transducer_stateless2/jit_pretrained.py \
+  --encoder-model-filename ./conv_emformer_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt \
+  --decoder-model-filename ./conv_emformer_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt \
+  --joiner-model-filename ./conv_emformer_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  /path/to/foo.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List, Optional
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--encoder-model-filename",
+        type=str,
+        required=True,
+        help="Path to the encoder torchscript model. ",
+    )
+
+    parser.add_argument(
+        "--decoder-model-filename",
+        type=str,
+        required=True,
+        help="Path to the decoder torchscript model. ",
+    )
+
+    parser.add_argument(
+        "--joiner-model-filename",
+        type=str,
+        required=True,
+        help="Path to the joiner torchscript model. ",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "sound_file",
+        type=str,
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="Context size of the decoder model",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def greedy_search(
+    decoder: torch.jit.ScriptModule,
+    joiner: torch.jit.ScriptModule,
+    encoder_out: torch.Tensor,
+    decoder_out: Optional[torch.Tensor] = None,
+    hyp: Optional[List[int]] = None,
+):
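+    """Run greedy search on one chunk of encoder output, carrying the partial
+    hypothesis and the cached decoder output across chunks.
+
+    Args:
+      decoder: The traced decoder model.
+      joiner: The traced joiner model.
+      encoder_out: Encoder output of the current chunk, of shape (T, encoder_dim).
+      decoder_out: Decoder output from the previous chunk, of shape
+        (1, decoder_dim), or None for the very first chunk.
+      hyp: Token IDs decoded so far (including the initial blanks), or None
+        for the very first chunk.
+
+    Returns:
+      Return the updated (hyp, decoder_out).
+    """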
+    assert encoder_out.ndim == 2
+    context_size = 2
+    blank_id = 0
+
+    if decoder_out is None:
+        assert hyp is None, hyp
+        hyp = [blank_id] * context_size
+        decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
+        # decoder_input.shape is (1, context_size)
+        decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1)
+    else:
+        assert decoder_out.ndim == 2
+        assert hyp is not None, hyp
+
+    T = encoder_out.size(0)
+    for i in range(T):
+        cur_encoder_out = encoder_out[i : i + 1]
+        joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
+        y = joiner_out.argmax(dim=0).item()
+
+        if y != blank_id:
+            hyp.append(y)
+            decoder_input = hyp[-context_size:]
+
+            decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
+            decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1)
+
+    return hyp, decoder_out
+
+
+def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
+    """Create a CPU streaming feature extractor.
+
+    At present, we assume it returns a fbank feature extractor with
+    fixed options. In the future, we will support passing in the options
+    from outside.
+
+    Returns:
+      Return a CPU streaming feature extractor.
+    """
+    opts = FbankOptions()
+    opts.device = "cpu"
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = sample_rate
+    opts.mel_opts.num_bins = 80
+    return OnlineFbank(opts)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    encoder = torch.jit.load(args.encoder_model_filename)
+    decoder = torch.jit.load(args.decoder_model_filename)
+    joiner = torch.jit.load(args.joiner_model_filename)
+
+    encoder.eval()
+    decoder.eval()
+    joiner.eval()
+
+    encoder.to(device)
+    decoder.to(device)
+    joiner.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    online_fbank = create_streaming_feature_extractor(args.sample_rate)
+
+    logging.info(f"Reading sound files: {args.sound_file}")
+    wave_samples = read_sound_files(
+        filenames=[args.sound_file],
+        expected_sample_rate=args.sample_rate,
+    )[0]
+    logging.info(wave_samples.shape)
+
+    logging.info("Decoding started")
+    chunk_length = encoder.chunk_length
+    right_context_length = encoder.right_context_length
+
+    # Assume the subsampling factor is 4
+    pad_length = right_context_length + 2 * 4 + 3
+    T = chunk_length + pad_length
+
+    logging.info(f"chunk_length: {chunk_length}")
+    logging.info(f"right_context_length: {right_context_length}")
+
+    states = encoder.init_states(device)
+    logging.info(f"num layers: {len(states)//4}")
+
+    tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
+
+    wave_samples = torch.cat([wave_samples, tail_padding])
+
+    chunk = int(0.25 * args.sample_rate)  # 0.25 second
+    num_processed_frames = 0
+
+    hyp = None
+    decoder_out = None
+
+    start = 0
+    while start < wave_samples.numel():
+        logging.info(f"{start}/{wave_samples.numel()}")
+        end = min(start + chunk, wave_samples.numel())
+        samples = wave_samples[start:end]
+        start += chunk
+        online_fbank.accept_waveform(
+            sampling_rate=args.sample_rate,
+            waveform=samples,
+        )
+        while online_fbank.num_frames_ready - num_processed_frames >= T:
+            frames = []
+            for i in range(T):
+                frames.append(online_fbank.get_frame(num_processed_frames + i))
+            num_processed_frames += chunk_length
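+            # Only chunk_length new frames are consumed per step; the remaining
+            # pad_length frames are look-ahead that overlaps with the next window.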
+            frames = torch.cat(frames, dim=0).unsqueeze(0)
+            # TODO(fangjun): remove x_lens
+            x_lens = torch.tensor([T])
+            encoder_out, _, states = encoder(frames, x_lens, states)
+
+            hyp, decoder_out = greedy_search(
+                decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
+            )
+
+    context_size = 2
+
+    logging.info(args.sound_file)
+    logging.info(sp.decode(hyp[context_size:]))
+
+    logging.info("Decoding Done")
+
+
+torch.set_num_threads(4)
+torch.set_num_interop_threads(1)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
new file mode 100755
index 000000000..b21fe5c7e
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
@@ -0,0 +1,387 @@
+#!/usr/bin/env python3
+#
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang, Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+  --tokens ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/tokens.txt \
+  --encoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+  --encoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+  --decoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+  --decoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+  --joiner-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+  --joiner-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+  ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/test_wavs/1089-134686-0001.wav
+
+You can find pretrained models at
+https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04
+"""
+
+import argparse
+import logging
+from typing import List, Optional, Tuple
+
+import k2
+import ncnn
+import torch
+import torchaudio
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--encoder-param-filename",
+        type=str,
+        help="Path to encoder.ncnn.param",
+    )
+
+    parser.add_argument(
+        "--encoder-bin-filename",
+        type=str,
+        help="Path to encoder.ncnn.bin",
+    )
+
+    parser.add_argument(
+        "--decoder-param-filename",
+        type=str,
+        help="Path to decoder.ncnn.param",
+    )
+
+    parser.add_argument(
+        "--decoder-bin-filename",
+        type=str,
+        help="Path to decoder.ncnn.bin",
+    )
+
+    parser.add_argument(
+        "--joiner-param-filename",
+        type=str,
+        help="Path to joiner.ncnn.param",
+    )
+
+    parser.add_argument(
+        "--joiner-bin-filename",
+        type=str,
+        help="Path to joiner.ncnn.bin",
+    )
+
+    parser.add_argument(
+        "sound_filename",
+        type=str,
+        help="Path to foo.wav",
+    )
+
+    return parser.parse_args()
+
+
+class Model:
+    def __init__(self, args):
+        self.init_encoder(args)
+        self.init_decoder(args)
+        self.init_joiner(args)
+
+        self.num_layers = 12
+        self.memory_size = 32
+        self.d_model = 512
+        self.cnn_module_kernel = 31
+
+        self.left_context_length = 32 // 4  # after subsampling
+        self.chunk_length = 32  # before subsampling
+        right_context_length = 8  # before subsampling
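+        # Extra frames each chunk must carry: the right context plus (presumably)
+        # the context consumed by the 4x convolutional subsampling front-end.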
+        pad_length = right_context_length + 2 * 4 + 3
+        self.T = self.chunk_length + pad_length
+        print("T", self.T, self.chunk_length)
+
+    def get_init_states(self) -> List[torch.Tensor]:
+        states = []
+
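+        # Each layer keeps 4 state tensors. Judging from the shapes created here
+        # and consumed in run_encoder(), they are presumably: the attention
+        # memory bank (memory_size, d_model), two left-context caches of shape
+        # (left_context_length, d_model), and the causal-convolution cache
+        # (d_model, cnn_module_kernel - 1).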
+        for i in range(self.num_layers):
+            s0 = torch.zeros(self.memory_size, self.d_model)
+            s1 = torch.zeros(self.left_context_length, self.d_model)
+            s2 = torch.zeros(self.left_context_length, self.d_model)
+            s3 = torch.zeros(self.d_model, self.cnn_module_kernel - 1)
+            states.extend([s0, s1, s2, s3])
+
+        return states
+
+    def init_encoder(self, args):
+        encoder_net = ncnn.Net()
+        encoder_net.opt.use_packing_layout = False
+        encoder_net.opt.use_fp16_storage = False
+        encoder_param = args.encoder_param_filename
+        encoder_model = args.encoder_bin_filename
+
+        encoder_net.load_param(encoder_param)
+        encoder_net.load_model(encoder_model)
+
+        self.encoder_net = encoder_net
+
+    def init_decoder(self, args):
+        decoder_param = args.decoder_param_filename
+        decoder_model = args.decoder_bin_filename
+
+        decoder_net = ncnn.Net()
+
+        decoder_net.load_param(decoder_param)
+        decoder_net.load_model(decoder_model)
+
+        self.decoder_net = decoder_net
+
+    def init_joiner(self, args):
+        joiner_param = args.joiner_param_filename
+        joiner_model = args.joiner_bin_filename
+        joiner_net = ncnn.Net()
+        joiner_net.load_param(joiner_param)
+        joiner_net.load_model(joiner_model)
+
+        self.joiner_net = joiner_net
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+          x:
+            A tensor of shape (T, C)
+          states:
+            A list of tensors. len(states) == self.num_layers * 4
+        Returns:
+          Return a tuple containing:
+           - encoder_out, a tensor of shape (T, encoder_dim).
+           - next_states, a list of tensors containing the next states
+        """
+        with self.encoder_net.create_extractor() as ex:
+            ex.set_num_threads(4)
+            ex.input("in0", ncnn.Mat(x.numpy()).clone())
+
+            # in0 is the feature input; each encoder layer then takes 4 state
+            # inputs: layer0 uses in1-in4, layer1 uses in5-in8, and so on.
+            for i in range(self.num_layers):
+                offset = 1 + i * 4
+                name = f"in{offset}"
+                # (32, 1, 512) -> (32, 512)
+                ex.input(name, ncnn.Mat(states[i * 4 + 0].numpy()).clone())
+
+                name = f"in{offset+1}"
+                #  (8, 1, 512) -> (8, 512)
+                ex.input(name, ncnn.Mat(states[i * 4 + 1].numpy()).clone())
+
+                name = f"in{offset+2}"
+                #  (8, 1, 512) -> (8, 512)
+                ex.input(name, ncnn.Mat(states[i * 4 + 2].numpy()).clone())
+
+                name = f"in{offset+3}"
+                #  (1, 512, 30) -> (512, 30)
+                ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
+
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
+
+            out_states: List[torch.Tensor] = []
+            for i in range(4 * self.num_layers):
+                name = f"out{i+1}"
+                ret, ncnn_out_state = ex.extract(name)
+                assert ret == 0, ret
+                ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
+                out_states.append(ncnn_out_state)
+
+            return encoder_out, out_states
+
+    def run_decoder(self, decoder_input):
+        assert decoder_input.dtype == torch.int32
+
+        with self.decoder_net.create_extractor() as ex:
+            ex.set_num_threads(4)
+            ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
+            return decoder_out
+
+    def run_joiner(self, encoder_out, decoder_out):
+        with self.joiner_net.create_extractor() as ex:
+            ex.set_num_threads(4)
+            ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
+            ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
+            return joiner_out
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def create_streaming_feature_extractor() -> OnlineFeature:
+    """Create a CPU streaming feature extractor.
+
+    At present, we assume it returns a fbank feature extractor with
+    fixed options. In the future, we will support passing in the options
+    from outside.
+
+    Returns:
+      Return a CPU streaming feature extractor.
+    """
+    opts = FbankOptions()
+    opts.device = "cpu"
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+    return OnlineFbank(opts)
+
+
+def greedy_search(
+    model: Model,
+    encoder_out: torch.Tensor,
+    decoder_out: Optional[torch.Tensor] = None,
+    hyp: Optional[List[int]] = None,
+):
+    context_size = 2
+    blank_id = 0
+
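+    # The decoder is "stateless": it conditions only on the last `context_size`
+    # tokens, so we simply re-run it whenever a non-blank symbol is emitted and
+    # keep reusing the cached decoder_out otherwise.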
+    if decoder_out is None:
+        assert hyp is None, hyp
+        hyp = [blank_id] * context_size
+        decoder_input = torch.tensor(hyp, dtype=torch.int32)  # (context_size,)
+        decoder_out = model.run_decoder(decoder_input).squeeze(0)
+    else:
+        assert decoder_out.ndim == 1
+        assert hyp is not None, hyp
+
+    T = encoder_out.size(0)
+    for t in range(T):
+        cur_encoder_out = encoder_out[t]
+
+        joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
+        y = joiner_out.argmax(dim=0).item()
+        if y != blank_id:
+            hyp.append(y)
+            decoder_input = hyp[-context_size:]
+            decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
+            decoder_out = model.run_decoder(decoder_input).squeeze(0)
+
+    return hyp, decoder_out
+
+
+def main():
+    args = get_args()
+    logging.info(vars(args))
+
+    model = Model(args)
+
+    sound_file = args.sound_filename
+
+    sample_rate = 16000
+
+    logging.info("Constructing Fbank computer")
+    online_fbank = create_streaming_feature_extractor()
+
+    logging.info(f"Reading sound files: {sound_file}")
+    wave_samples = read_sound_files(
+        filenames=[sound_file],
+        expected_sample_rate=sample_rate,
+    )[0]
+    logging.info(wave_samples.shape)
+
+    tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
+
+    wave_samples = torch.cat([wave_samples, tail_padding])
+
+    states = model.get_init_states()
+
+    hyp = None
+    decoder_out = None
+
+    num_processed_frames = 0
+    segment = model.T
+    offset = model.chunk_length
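+    # Each encoder call consumes `segment` (= chunk + padding) feature frames,
+    # but the read pointer only advances by `offset` (= chunk_length) frames;
+    # the trailing frames act as lookahead and are fed again with the next chunk.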
+
+    chunk = int(1 * sample_rate)  # 1 second
+
+    start = 0
+    while start < wave_samples.numel():
+        end = min(start + chunk, wave_samples.numel())
+        samples = wave_samples[start:end]
+        start += chunk
+
+        online_fbank.accept_waveform(
+            sampling_rate=sample_rate,
+            waveform=samples,
+        )
+        while online_fbank.num_frames_ready - num_processed_frames >= segment:
+            frames = []
+            for i in range(segment):
+                frames.append(online_fbank.get_frame(num_processed_frames + i))
+            num_processed_frames += offset
+            frames = torch.cat(frames, dim=0)
+            encoder_out, states = model.run_encoder(frames, states)
+            hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
+
+    symbol_table = k2.SymbolTable.from_file(args.tokens)
+
+    context_size = 2
+    text = ""
+    for i in hyp[context_size:]:
+        text += symbol_table[i]
+    text = text.replace("▁", " ").strip()
+
+    logging.info(sound_file)
+    logging.info(text)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
new file mode 100755
index 000000000..c91f94876
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
@@ -0,0 +1,1128 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Mingshuang Luo,
+#                                                  Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conv_emformer_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conv_emformer_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 280 \
+  --master-port 12321 \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+# For mixed precision training:
+./conv_emformer_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conv_emformer_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 300 \
+  --master-port 12321 \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from emformer2 import Emformer
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Attention dim for the Emformer",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads for the Emformer",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feed-forward dimension for the Emformer",
+    )
+
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of encoder layers for the Emformer",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=int,
+        default=31,
+        help="Kernel size for the convolution module.",
+    )
+
+    parser.add_argument(
+        "--left-context-length",
+        type=int,
+        default=32,
+        help="""Number of frames before subsampling for left context
+        in the Emformer.""",
+    )
+
+    parser.add_argument(
+        "--chunk-length",
+        type=int,
+        default=32,
+        help="""Number of frames before subsampling for each chunk
+        in the Emformer.""",
+    )
+
+    parser.add_argument(
+        "--right-context-length",
+        type=int,
+        default=8,
+        help="""Number of frames before subsampling for right context
+        in the Emformer.""",
+    )
+
+    parser.add_argument(
+        "--memory-size",
+        type=int,
+        default=0,
+        help="Number of entries in the memory for the Emformer",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="""The initial learning rate. This value should not need to be
+        changed.""",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to write statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for Emformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for decoder
+            "decoder_dim": 512,
+            # parameters for joiner
+            "joiner_dim": 512,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Emformer(
+        num_features=params.feature_dim,
+        chunk_length=params.chunk_length,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        cnn_module_kernel=params.cnn_module_kernel,
+        left_context_length=params.left_context_length,
+        right_context_length=params.right_context_length,
+        memory_size=params.memory_size,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute RNN-T loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Transducer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+        )
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        scaler.scale(loss).backward()
+        scheduler.step_batch(params.batch_idx_train)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    if params.full_libri:
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()

From 10472e7ffc8bd3f8a096eb7cc62c86a4b861a9a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?=
 <53865510+ahazned@users.noreply.github.com>
Date: Wed, 7 Dec 2022 03:22:50 +0300
Subject: [PATCH 065/120] Update prepare.sh (#737)

---
 egs/spgispeech/ASR/prepare.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh
index 4842f52d0..8331f94d5 100755
--- a/egs/spgispeech/ASR/prepare.sh
+++ b/egs/spgispeech/ASR/prepare.sh
@@ -108,7 +108,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     pieces=$(find data/manifests -name "cuts_train_[0-9]*.jsonl.gz")
     lhotse combine $pieces data/manifests/cuts_train.jsonl.gz
   fi
-  gunzip -c data/manifests/train_cuts.jsonl.gz | shuf | gzip -c > data/manifests/train_cuts_shuf.jsonl.gz
+  gunzip -c data/manifests/cuts_train.jsonl.gz | shuf | gzip -c > data/manifests/cuts_train_shuf.jsonl.gz
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -136,7 +136,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     # Add special words to words.txt
     echo " 0" > $lang_dir/words.txt
     echo "!SIL 1" >> $lang_dir/words.txt
-    echo "[UNK] 2" >> $lang_dir/words.txt
+    echo " 2" >> $lang_dir/words.txt
 
     # Add regular words to words.txt
     gunzip -c data/manifests/cuts_train_raw.jsonl.gz \

From 0e325c8782c8b9178cf0f2b030e49ae64f2b091d Mon Sep 17 00:00:00 2001
From: huangruizhe 
Date: Wed, 7 Dec 2022 02:43:26 -0500
Subject: [PATCH 066/120] Fixed rnn_lm model.py (#738)

---
 icefall/rnn_lm/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index 9eef88840..3598a4857 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
                 device
             )
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
                 device
             )
 
@@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
 
         device = next(self.parameters()).device
 

From d65fe17d2766e34adbb4080f9691ea829ac0ae05 Mon Sep 17 00:00:00 2001
From: armusc <46787089+armusc@users.noreply.github.com>
Date: Thu, 8 Dec 2022 13:21:51 +0100
Subject: [PATCH 067/120] Update train.py with parameters_names as required by
 optimizer initialization (#742)

* Update train.py
---
 egs/ami/ASR/pruned_transducer_stateless7/train.py     | 11 ++++++++++-
 .../ASR/pruned_transducer_stateless7_ctc/train.py     | 11 ++++++++++-
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
index b5efb3405..81823ced2 100755
--- a/egs/ami/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -972,7 +972,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
index abfd56e5a..162ad8412 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -1036,7 +1036,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

From 4501821fd98821a6cf3a238c6dc5c01422643fdb Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Fri, 9 Dec 2022 16:46:44 +0800
Subject: [PATCH 068/120] Support using OpenFst to compile HLG. (#606)

* Support using OpenFst to compile HLG.

* Fix style issues
---
 .../ASR/local/compile_hlg_using_openfst.py    | 184 ++++++++++++++++++
 egs/librispeech/ASR/prepare.sh                |  41 +++-
 icefall/shared/convert-k2-to-openfst.py       | 102 ++++++++++
 requirements.txt                              |   1 +
 4 files changed, 325 insertions(+), 3 deletions(-)
 create mode 100755 egs/librispeech/ASR/local/compile_hlg_using_openfst.py
 create mode 100755 icefall/shared/convert-k2-to-openfst.py

diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py
new file mode 100755
index 000000000..9e5e3df69
--- /dev/null
+++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+    - L, the lexicon, built from lang_dir/L_disambig.fst
+
+        Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from data/lm/G_3_gram.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG_fst.pt
+
+So when should you use this script instead of ./local/compile_hlg.py?
+If you have a very large G, ./local/compile_hlg.py may run out of memory
+during determinization. In that case, you can use this script to compile HLG.
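+
+For example (a minimal invocation; the lang dir is just the one produced
+by ./prepare.sh):
+
+    ./local/compile_hlg_using_openfst.py --lang-dir data/lang_bpe_500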
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+
+    Return:
+      An FST representing HLG.
+    """
+
+    L = kaldifst.StdVectorFst.read(f"{lang_dir}/L_disambig.fst")
+    logging.info("Arc sort L")
+    kaldifst.arcsort(L, sort_type="olabel")
+    logging.info(f"L: #states {L.num_states}")
+
+    G_filename_txt = "data/lm/G_3_gram.fst.txt"
+    G_filename_binary = "data/lm/G_3_gram.fst"
+    if Path(G_filename_binary).is_file():
+        logging.info(f"Loading {G_filename_binary}")
+        G = kaldifst.StdVectorFst.read(G_filename_binary)
+    else:
+        logging.info(f"Loading {G_filename_txt}")
+        with open(G_filename_txt) as f:
+            G = kaldifst.compile(s=f.read(), acceptor=False)
+            logging.info(f"Saving G to {G_filename_binary}")
+            G.write(G_filename_binary)
+
+    logging.info("Arc sort G")
+    kaldifst.arcsort(G, sort_type="ilabel")
+
+    logging.info(f"G: #states {G.num_states}")
+
+    logging.info("Compose L and G and connect LG")
+    LG = kaldifst.compose(L, G, connect=True)
+    logging.info(f"LG: #states {LG.num_states}")
+
+    logging.info("Determinizestar LG")
+    kaldifst.determinize_star(LG)
+    logging.info(f"LG after determinize_star: #states {LG.num_states}")
+
+    logging.info("Minimize encoded LG")
+    kaldifst.minimize_encoded(LG)
+    logging.info(f"LG after minimize_encoded: #states {LG.num_states}")
+
+    logging.info("Converting LG to k2 format")
+    LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)
+    logging.info(f"LG in k2: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}")
+
+    lexicon = Lexicon(lang_dir)
+
+    first_token_disambig_id = lexicon.token_table["#0"]
+    first_word_disambig_id = lexicon.word_table["#0"]
+    logging.info(f"token id for #0: {first_token_disambig_id}")
+    logging.info(f"word id for #0: {first_word_disambig_id}")
+
+    max_token_id = max(lexicon.tokens)
+    modified = False
+    logging.info(
+        f"Building ctc_topo. modified: {modified}, max_token_id: {max_token_id}"
+    )
+
+    H = k2.ctc_topo(max_token_id, modified=modified)
+    logging.info(f"H: #states: {H.shape[0]}, #arcs: {H.num_arcs}")
+
+    logging.info("Removing disambiguation symbols on LG")
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0
+
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set LG.properties to None
+    LG.__dict__["_properties"] = None
+
+    logging.info("Removing epsilons from LG")
+    LG = k2.remove_epsilon(LG)
+    logging.info(
+        f"LG after k2.remove_epsilon: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}"
+    )
+
+    logging.info("Connecting LG after removing epsilons")
+    LG = k2.connect(LG)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+    logging.info(f"LG after k2.connect: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}")
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    logging.info("Composing H and LG")
+
+    HLG = k2.compose(H, LG, inner_labels="tokens")
+    logging.info(
+        f"HLG after k2.compose: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}"
+    )
+
+    logging.info("Connecting HLG")
+    HLG = k2.connect(HLG)
+    logging.info(
+        f"HLG after k2.connect: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}"
+    )
+
+    logging.info("Arc sorting LG")
+    HLG = k2.arc_sort(HLG)
+
+    return HLG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    filename = lang_dir / "HLG_fst.pt"
+
+    if filename.is_file():
+        logging.info(f"{filename} already exists - skipping")
+        return
+
+    HLG = compile_HLG(lang_dir)
+    logging.info(f"Saving HLG to {filename}")
+    torch.save(HLG.as_dict(), filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 542bbcdd8..11c8e1066 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -44,9 +44,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  # 5000
+  # 2000
+  # 1000
   500
 )
 
@@ -168,6 +168,22 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   if [ ! -f $lang_dir/L_disambig.pt ]; then
     ./local/prepare_lang.py --lang-dir $lang_dir
   fi
+
+  if [ ! -f $lang_dir/L.fst ]; then
+    log "Converting L.pt to L.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L.pt \
+      $lang_dir/L.fst
+  fi
+
+  if [ ! -f $lang_dir/L_disambig.fst ]; then
+    log "Converting L_disambig.pt to L_disambig.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L_disambig.pt \
+      $lang_dir/L_disambig.fst
+  fi
 fi
 
 
@@ -208,6 +224,22 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
         --lexicon $lang_dir/lexicon.txt \
         --bpe-model $lang_dir/bpe.model
     fi
+
+    if [ ! -f $lang_dir/L.fst ]; then
+      log "Converting L.pt to L.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L.pt \
+        $lang_dir/L.fst
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.fst ]; then
+      log "Converting L_disambig.pt to L_disambig.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L_disambig.pt \
+        $lang_dir/L_disambig.fst
+    fi
   done
 fi
 
@@ -270,10 +302,13 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone
+  ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
     ./local/compile_hlg.py --lang-dir $lang_dir
+
+    ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
   done
 fi
 
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
new file mode 100755
index 000000000..29a2cd7f7
--- /dev/null
+++ b/icefall/shared/convert-k2-to-openfst.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes as input an FST in k2 format and converts it
+to an FST in OpenFST format.
+
+The generated FST is saved into a binary file and its type is
+StdVectorFst.
+
+Usage examples:
+(1) Convert an acceptor
+
+  ./convert-k2-to-openfst.py in.pt binary.fst
+
+(2) Convert a transducer
+
+  ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
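+
+(3) A sketch matching how ./prepare.sh in the LibriSpeech recipe uses this
+    script (the data/lang_bpe_500 paths are assumed to exist from the earlier
+    lang-dir preparation stages):
+
+  ./convert-k2-to-openfst.py --olabels aux_labels \
+    data/lang_bpe_500/L_disambig.pt data/lang_bpe_500/L_disambig.fst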
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst.utils
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--olabels",
+        type=str,
+        default=None,
+        help="""If not empty, the input FST is assumed to be a transducer
+        and we use its attribute specified by "olabels" as the output labels.
+        """,
+    )
+    parser.add_argument(
+        "input_filename",
+        type=str,
+        help="Path to the input FST in k2 format",
+    )
+
+    parser.add_argument(
+        "output_filename",
+        type=str,
+        help="Path to the output FST in OpenFst format",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    logging.info(f"{vars(args)}")
+
+    input_filename = args.input_filename
+    output_filename = args.output_filename
+    olabels = args.olabels
+
+    if Path(output_filename).is_file():
+        logging.info(f"{output_filename} already exists - skipping")
+        return
+
+    assert Path(input_filename).is_file(), f"{input_filename} does not exist"
+    logging.info(f"Loading {input_filename}")
+    k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
+    if olabels:
+        assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
+
+    p = Path(output_filename).parent
+    if not p.is_dir():
+        logging.info(f"Creating {p}")
+        p.mkdir(parents=True)
+
+    logging.info("Converting (May take some time if the input FST is large)")
+    fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
+    logging.info(f"Saving to {output_filename}")
+    fst.write(output_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/requirements.txt b/requirements.txt
index 5e32af853..a07f6b7c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+kaldifst
 kaldilm
 kaldialign
 sentencepiece>=0.1.96

From a0cf85343dad31a678ddaac7652f0bb2bbb4cac2 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Fri, 9 Dec 2022 19:23:11 +0800
Subject: [PATCH 069/120] fix for memory usage in
 pruned_transducer_stateless7/scaling.py (#752)

Co-authored-by: yifanyang 
---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 6f63e0629..042c9c3e4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -562,7 +562,7 @@ class ActivationBalancer(torch.nn.Module):
                 sign_factor = None
 
             scale_factor = _compute_scale_factor(
-                x,
+                x.detach(),
                 self.channel_dim,
                 min_abs=self.min_abs,
                 max_abs=self.max_abs,

From c4aaf3ea3bfcebdad79f4e9d10080ed514113830 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Sat, 10 Dec 2022 15:45:23 +0530
Subject: [PATCH 070/120] Add AliMeeting multi-condition training recipe (#751)

* add AliMeeting multi-domain recipe

* convert scripts to symbolic links
---
 egs/alimeeting/ASR_v2/README.md               |   38 +
 egs/alimeeting/ASR_v2/RESULTS.md              |   90 ++
 egs/alimeeting/ASR_v2/local/__init__.py       |    0
 .../ASR_v2/local/compute_fbank_alimeeting.py  |  193 +++
 .../ASR_v2/local/compute_fbank_musan.py       |    1 +
 .../local/prepare_alimeeting_enhanced.py      |  158 +++
 .../ASR_v2/local/prepare_alimeeting_gss.sh    |   98 ++
 egs/alimeeting/ASR_v2/local/prepare_char.py   |    1 +
 egs/alimeeting/ASR_v2/local/prepare_words.py  |    1 +
 egs/alimeeting/ASR_v2/local/text2segments.py  |    1 +
 egs/alimeeting/ASR_v2/local/text2token.py     |    1 +
 egs/alimeeting/ASR_v2/prepare.sh              |  125 ++
 .../pruned_transducer_stateless7/__init__.py  |    0
 .../asr_datamodule.py                         |  419 ++++++
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless7/decode.py    |  698 ++++++++++
 .../pruned_transducer_stateless7/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless7/export.py    |  320 +++++
 .../jit_pretrained.py                         |    1 +
 .../pruned_transducer_stateless7/joiner.py    |    1 +
 .../pruned_transducer_stateless7/model.py     |    1 +
 .../pruned_transducer_stateless7/optim.py     |    1 +
 .../pretrained.py                             |    1 +
 .../pruned_transducer_stateless7/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../test_model.py                             |    1 +
 .../pruned_transducer_stateless7/train.py     | 1186 +++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |    1 +
 egs/alimeeting/ASR_v2/shared                  |    1 +
 30 files changed, 3343 insertions(+)
 create mode 100644 egs/alimeeting/ASR_v2/README.md
 create mode 100644 egs/alimeeting/ASR_v2/RESULTS.md
 create mode 100644 egs/alimeeting/ASR_v2/local/__init__.py
 create mode 100755 egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
 create mode 120000 egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
 create mode 100644 egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
 create mode 100755 egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
 create mode 120000 egs/alimeeting/ASR_v2/local/prepare_char.py
 create mode 120000 egs/alimeeting/ASR_v2/local/prepare_words.py
 create mode 120000 egs/alimeeting/ASR_v2/local/text2segments.py
 create mode 120000 egs/alimeeting/ASR_v2/local/text2token.py
 create mode 100755 egs/alimeeting/ASR_v2/prepare.sh
 create mode 100644 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py
 create mode 100644 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
 create mode 100755 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
 create mode 100755 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
 create mode 100755 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/alimeeting/ASR_v2/shared

diff --git a/egs/alimeeting/ASR_v2/README.md b/egs/alimeeting/ASR_v2/README.md
new file mode 100644
index 000000000..f70327501
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/README.md
@@ -0,0 +1,38 @@
+
+# Introduction
+
+This recipe trains multi-domain ASR models for AliMeeting. By multi-domain, we mean that
+we train a single model on close-talk and far-field conditions. This recipe optionally
+uses [GSS](https://github.com/desh2608/gss)-based enhancement for the far-field array microphones.
+We pool data in the following 4 ways and train a single model on the pooled data:
+
+(i) Individual headset microphone (IHM)
+(ii) IHM with simulated reverb
+(iii) Single distant microphone (SDM)
+(iv) GSS-enhanced array microphones
+
+This is different from `alimeeting/ASR` since that recipe trains a model only on the
+far-field audio. Additionally, we use text normalization here similar to the original
+M2MeT challenge, so the results should be more comparable to those from Table 4 of
+the [paper](https://arxiv.org/abs/2110.07393).
+
+The following additional packages need to be installed to run this recipe:
+* `pip install jieba`
+* `pip install paddlepaddle`
+* `pip install git+https://github.com/desh2608/gss.git`
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+## Performance Record
+
+### pruned_transducer_stateless7
+
+The following are decoded using `modified_beam_search`:
+
+| Evaluation set           | eval CER    | test CER |
+|--------------------------|------------|---------|
+| IHM                      |  9.58  | 11.53 |
+| SDM                      |  23.37  | 25.85 |
+| MDM (GSS-enhanced)       |  11.82  | 14.22 |
+
+See [RESULTS](/egs/alimeeting/ASR_v2/RESULTS.md) for details.
diff --git a/egs/alimeeting/ASR_v2/RESULTS.md b/egs/alimeeting/ASR_v2/RESULTS.md
new file mode 100644
index 000000000..15b24250d
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/RESULTS.md
@@ -0,0 +1,90 @@
+## Results (CER)
+
+#### 2022-12-09
+
+#### Zipformer (pruned_transducer_stateless7)
+
+Zipformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform tensor dim).
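+
+For intuition, a minimal PyTorch sketch of such a decoder is shown below (illustrative
+only; the actual implementation lives in ./pruned_transducer_stateless7/decoder.py and
+differs in details such as the scaled modules it uses):
+
+```
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TinyStatelessDecoder(nn.Module):
+    """A toy stateless decoder: embedding -> causal Conv1d (kernel 2) -> linear."""
+
+    def __init__(self, vocab_size: int, embed_dim: int, out_dim: int):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        # Kernel size 2: each prediction depends only on the previous two tokens.
+        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=2)
+        self.out = nn.Linear(embed_dim, out_dim)
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) token ids
+        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
+        emb = F.pad(emb, (1, 0))                  # left-pad in time => causal
+        out = self.conv(emb).permute(0, 2, 1)     # (N, U, C)
+        return self.out(out)                      # (N, U, out_dim)
+```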
+
+All the results below use a single model trained by combining the following
+data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
+augmentation are applied on top of the pooled data.
+
+**CERs for IHM:**
+
+|                           | eval | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  10.13  |  12.21  | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search      |  9.58  |  11.53  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  9.92  |  12.07  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**CERs for SDM:**
+
+|                           | eval | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  23.70  |  26.41  | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search      |  23.37  |  25.85  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  23.60  |  26.38  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**CERs for GSS-enhanced MDM:**
+
+|                           | eval | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  12.24  |  14.99  | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search      |  11.82  |  14.22  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  12.30  |  14.98  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+The training command for reproducing the above results is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 300 \
+  --max-cuts 100 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25 \
+  --use-fp16 True
+```
+
+The decoding command is:
+```
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method greedy_search
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+```
+
+Pretrained model is available at 
+
+The tensorboard training log can be found at
+
diff --git a/egs/alimeeting/ASR_v2/local/__init__.py b/egs/alimeeting/ASR_v2/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
new file mode 100755
index 000000000..c6aa2ab36
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins University        (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AliMeeting dataset.
+For the training data, we prepare IHM, reverberated IHM, SDM, and GSS-enhanced
+audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
+parts (which are the 3 evaluation settings).
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
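+
+It is invoked from ./prepare.sh (stage 5) as:
+
+  python local/compute_fbank_alimeeting.py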
+"""
+import logging
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    KaldifeatFrameOptions,
+    KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_alimeeting():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    sampling_rate = 16000
+    num_mel_bins = 80
+
+    extractor = KaldifeatFbank(
+        KaldifeatFbankConfig(
+            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+            device="cuda",
+        )
+    )
+
+    logging.info("Reading manifests")
+    manifests_ihm = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=src_dir,
+        prefix="alimeeting-ihm",
+        suffix="jsonl.gz",
+    )
+    manifests_sdm = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=src_dir,
+        prefix="alimeeting-sdm",
+        suffix="jsonl.gz",
+    )
+    # For GSS we already have cuts so we read them directly.
+    manifests_gss = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=src_dir,
+        prefix="alimeeting-gss",
+        suffix="jsonl.gz",
+    )
+
+    def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
+        cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
+        _ = cuts.compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=storage_path,
+            manifest_path=manifest_path,
+            batch_duration=5000,
+            num_workers=8,
+            storage_type=LilcomChunkyWriter,
+        )
+
+    logging.info(
+        "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
+    )
+
+    logging.info("Processing train split IHM")
+    cuts_ihm = (
+        CutSet.from_manifests(**manifests_ihm["train"])
+        .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+        .modify_ids(lambda x: x + "-ihm")
+    )
+    _extract_feats(
+        cuts_ihm,
+        output_dir / "feats_train_ihm",
+        src_dir / "cuts_train_ihm.jsonl.gz",
+    )
+
+    logging.info("Processing train split IHM + reverberated IHM")
+    cuts_ihm_rvb = cuts_ihm.reverb_rir()
+    _extract_feats(
+        cuts_ihm_rvb,
+        output_dir / "feats_train_ihm_rvb",
+        src_dir / "cuts_train_ihm_rvb.jsonl.gz",
+    )
+
+    logging.info("Processing train split SDM")
+    cuts_sdm = (
+        CutSet.from_manifests(**manifests_sdm["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-sdm")
+    )
+    _extract_feats(
+        cuts_sdm,
+        output_dir / "feats_train_sdm",
+        src_dir / "cuts_train_sdm.jsonl.gz",
+    )
+
+    logging.info("Processing train split GSS")
+    cuts_gss = (
+        CutSet.from_manifests(**manifests_gss["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-gss")
+    )
+    _extract_feats(
+        cuts_gss,
+        output_dir / "feats_train_gss",
+        src_dir / "cuts_train_gss.jsonl.gz",
+    )
+
+    logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
+    for split in ["eval", "test"]:
+        logging.info(f"Processing {split} IHM")
+        cuts_ihm = (
+            CutSet.from_manifests(**manifests_ihm[split])
+            .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_ihm",
+                manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} SDM")
+        cuts_sdm = (
+            CutSet.from_manifests(**manifests_sdm[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_sdm",
+                manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} GSS")
+        cuts_gss = (
+            CutSet.from_manifests(**manifests_gss[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_gss",
+                manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_alimeeting()
diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
new file mode 100644
index 000000000..f1512efa5
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Data preparation for AliMeeting GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
+    parser.add_argument(
+        "manifests_dir",
+        type=Path,
+        help="Path to directory containing AliMeeting manifests.",
+    )
+    parser.add_argument(
+        "enhanced_dir",
+        type=Path,
+        help="Path to enhanced data directory.",
+    )
+    parser.add_argument(
+        "--num-jobs",
+        "-j",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to run.",
+    )
+    parser.add_argument(
+        "--min-segment-duration",
+        "-d",
+        type=float,
+        default=0.0,
+        help="Minimum duration of a segment in seconds.",
+    )
+    return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+    """
+    Given a supervision (corresponding to an original AliMeeting recording), this function finds
+    the enhanced recording corresponding to the supervision, and returns this recording and
+    a new supervision whose start and end times are adjusted to match the enhanced recording.
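+
+    For illustration (hypothetical values): a supervision with recording_id "R0003_M0046",
+    speaker "SPK0093", start 12.25 s and end 15.75 s is looked up at
+    enhanced_dir/R0003_M0046/R0003_M0046-SPK0093-001225_001575.flac.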
+    """
+    file_name = Path(
+        f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
+    )
+    save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
+    if save_path.exists():
+        recording = Recording.from_file(save_path)
+        if recording.duration == 0:
+            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+            return None
+
+        # The old supervision is relative to the original recording; we create a new
+        # supervision relative to the enhanced segment.
+        new_supervision = fastcopy(
+            supervision,
+            recording_id=recording.id,
+            start=0,
+            duration=recording.duration,
+        )
+        return recording, new_supervision
+    else:
+        logging.warning(f"{save_path} does not exist.")
+        return None
+
+
+def main(args):
+    # Get arguments
+    manifests_dir = args.manifests_dir
+    enhanced_dir = args.enhanced_dir
+
+    # Load manifests from cache if they exist (saves time)
+    manifests = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=manifests_dir,
+        prefix="alimeeting-sdm",
+        suffix="jsonl.gz",
+    )
+    if not manifests:
+        raise ValueError(
+            "AliMeeting SDM manifests not found in {}".format(manifests_dir)
+        )
+
+    with ThreadPoolExecutor(args.num_jobs) as ex:
+        for part in ["train", "eval", "test"]:
+            logging.info(f"Processing {part}...")
+            supervisions_orig = manifests[part]["supervisions"].filter(
+                lambda s: s.duration >= args.min_segment_duration
+            )
+            futures = []
+
+            for supervision in tqdm(
+                supervisions_orig,
+                desc="Distributing tasks",
+            ):
+                futures.append(
+                    ex.submit(
+                        find_recording_and_create_new_supervision,
+                        enhanced_dir,
+                        supervision,
+                    )
+                )
+
+            recordings = []
+            supervisions = []
+            for future in tqdm(
+                futures,
+                total=len(futures),
+                desc="Processing tasks",
+            ):
+                result = future.result()
+                if result is not None:
+                    recording, new_supervision = result
+                    recordings.append(recording)
+                    supervisions.append(new_supervision)
+
+            # Remove duplicates from the recordings
+            recordings_nodup = {}
+            for recording in recordings:
+                if recording.id not in recordings_nodup:
+                    recordings_nodup[recording.id] = recording
+                else:
+                    logging.warning("Recording {} is duplicated.".format(recording.id))
+            recordings = RecordingSet.from_recordings(recordings_nodup.values())
+            supervisions = SupervisionSet.from_segments(supervisions)
+
+            recordings, supervisions = fix_manifests(
+                recordings=recordings, supervisions=supervisions
+            )
+
+            logging.info(f"Writing {part} enhanced manifests")
+            recordings.to_file(
+                manifests_dir / f"alimeeting-gss_recordings_{part}.jsonl.gz"
+            )
+            supervisions.to_file(
+                manifests_dir / f"alimeeting-gss_supervisions_{part}.jsonl.gz"
+            )
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
new file mode 100755
index 000000000..76db19832
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+# This script is used to run GSS-based enhancement on AliMeeting data.
+set -euo pipefail
+nj=4
+stage=0
+
+. shared/parse_options.sh || exit 1
+
+if [ $# != 2 ]; then
+   echo "Wrong #arguments ($#, expected 2)"
+   echo "Usage: local/prepare_alimeeting_gss.sh [options]  "
+   echo "e.g. local/prepare_alimeeting_gss.sh data/manifests exp/ami_gss"
+   echo "main options (for others, see top of script file)"
+   echo "  --nj                                 # number of parallel jobs"
+   echo "  --stage                           # stage to start running from"
+   exit 1;
+fi
+
+DATA_DIR=$1
+EXP_DIR=$2
+
+mkdir -p $EXP_DIR
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 1 ]; then
+  log "Stage 1: Prepare cut sets"
+  for part in train eval test; do
+    lhotse cut simple \
+      -r $DATA_DIR/alimeeting-mdm_recordings_${part}.jsonl.gz \
+      -s $DATA_DIR/alimeeting-mdm_supervisions_${part}.jsonl.gz \
+      $EXP_DIR/cuts_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 2 ]; then
+  log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
+  for part in train eval test; do
+    lhotse cut trim-to-supervisions --discard-overlapping \
+        $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 3 ]; then
+  log "Stage 3: Split manifests for multi-GPU processing (optional)"
+  for part in train eval test; do
+    gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj
+  done
+fi
+
+if [ $stage -le 4 ]; then
+  log "Stage 4: Enhance train segments using GSS (requires GPU)"
+  # for train, we use smaller context and larger batches to speed up processing
+  for JOB in $(seq $nj); do
+    gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 5.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 25.0 \
+      --max-batch-duration 60.0 \
+      --num-buckets 4 \
+      --num-workers 4
+  done
+fi
+
+if [ $stage -le 5 ]; then
+  log "Stage 5: Enhance eval/test segments using GSS (using GPU)"
+  # for eval/test, we use larger context and smaller batches to get better quality
+  for part in eval test; do
+    for JOB in $(seq $nj); do
+      gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
+      $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 15.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 16.0 \
+      --max-batch-duration 45.0 \
+      --num-buckets 4 \
+      --num-workers 4
+    done
+  done
+fi
+
+if [ $stage -le 6 ]; then
+  log "Stage 6: Prepare manifests for GSS-enhanced data"
+  python local/prepare_alimeeting_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
+fi
diff --git a/egs/alimeeting/ASR_v2/local/prepare_char.py b/egs/alimeeting/ASR_v2/local/prepare_char.py
new file mode 120000
index 000000000..ee5dd34f1
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_char.py
@@ -0,0 +1 @@
+../../ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_words.py b/egs/alimeeting/ASR_v2/local/prepare_words.py
new file mode 120000
index 000000000..970bfd60c
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_words.py
@@ -0,0 +1 @@
+../../ASR/local/prepare_words.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/text2segments.py b/egs/alimeeting/ASR_v2/local/text2segments.py
new file mode 120000
index 000000000..bf4547794
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/text2segments.py
@@ -0,0 +1 @@
+../../ASR/local/text2segments.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/text2token.py b/egs/alimeeting/ASR_v2/local/text2token.py
new file mode 120000
index 000000000..f6b8531b6
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/text2token.py
@@ -0,0 +1 @@
+../../ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh
new file mode 100755
index 000000000..76a108771
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/prepare.sh
@@ -0,0 +1,125 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+use_gss=true  # Use GSS-based enhancement with MDM setting
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/alimeeting
+#     This directory contains the following files downloaded from
+#       https://openslr.org/119/
+#
+#     - Train_Ali_far.tar.gz
+#     - Train_Ali_near.tar.gz
+#     - Test_Ali.tar.gz
+#     - Eval_Ali.tar.gz
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then
+    lhotse download ali-meeting $dl_dir/alimeeting
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare alimeeting manifest"
+  # We assume that you have downloaded the alimeeting corpus
+  # to $dl_dir/alimeeting
+  for part in ihm sdm mdm; do
+    mkdir -p data/manifests/alimeeting
+    lhotse prepare ali-meeting --mic $part --save-mono --normalize-text m2met \
+      $dl_dir/alimeeting data/manifests
+  done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
+  log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+  # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+  local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for alimeeting"
+  mkdir -p data/fbank
+  python local/compute_fbank_alimeeting.py
+  log "Combine features from train splits"
+  lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+    gzip -c > data/manifests/cuts_train_all.jsonl.gz
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare char based lang"
+  lang_char_dir=data/lang_char
+  mkdir -p $lang_char_dir
+
+  # Prepare text.
+  # Note: in Linux, you can install jq with the following command:
+  # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+  gunzip -c data/manifests/alimeeting-sdm_supervisions_train.jsonl.gz \
+    | jq ".text" | sed 's/"//g' \
+    | ./local/text2token.py -t "char" > $lang_char_dir/text
+
+  # Prepare words segments
+  python ./local/text2segments.py \
+    --input $lang_char_dir/text \
+    --output $lang_char_dir/text_words_segmentation
+
+  cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
+    | sort -u | sed "/^$/d" \
+    | uniq > $lang_char_dir/words_no_ids.txt
+
+  # Prepare words.txt
+  if [ ! -f $lang_char_dir/words.txt ]; then
+    ./local/prepare_words.py \
+      --input-file $lang_char_dir/words_no_ids.txt \
+      --output-file $lang_char_dir/words.txt
+  fi
+
+  if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+    ./local/prepare_char.py
+  fi
+fi
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..1cfd053c7
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,419 @@
+# Copyright      2021  Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class AlimeetingAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. the AliMeeting eval and
+    test sets in the IHM, SDM, and GSS conditions).
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+    This class should be derived for specific corpora used in ASR tasks.
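+
+    A minimal usage sketch (assuming `args` comes from an argparse parser that was
+    extended with AlimeetingAsrDataModule.add_arguments(parser)):
+
+        alimeeting = AlimeetingAsrDataModule(args)
+        train_dl = alimeeting.train_dataloaders(alimeeting.train_cuts())
+        valid_dl = alimeeting.valid_dataloaders(alimeeting.eval_ihm_cuts())
+        test_dl = alimeeting.test_dataloaders(alimeeting.test_sdm_cuts())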
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/manifests"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=100.0,
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
+        )
+        group.add_argument(
+            "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=50,
+            help=(
+                "The number of buckets for the BucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled (=default), the examples will be "
+                "shuffled for each epoch."
+            ),
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=8,
+            help=(
+                "The number of training dataloader workers that " "collect the batches."
+            ),
+        )
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        logging.info("About to get Musan cuts")
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                "Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=2,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        if self.args.on_the_fly_feats:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+            )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+            )
+
+        logging.info("Using DynamicBucketingSampler.")
+        train_sampler = DynamicBucketingSampler(
+            cuts_train,
+            max_duration=self.args.max_duration,
+            max_cuts=self.args.max_cuts,
+            shuffle=False,
+            num_buckets=self.args.num_buckets,
+            drop_last=True,
+        )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=True,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    def remove_short_cuts(self, cut: Cut) -> bool:
+        """
+        See: https://github.com/k2-fsa/icefall/issues/500
+        Basically, the zipformer model subsamples the input using the following formula:
+        num_out_frames = ((num_in_frames - 7)//2 + 1)//2
+        For num_out_frames to be at least 1, num_in_frames must be at least 9.
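+        With the 10 ms fbank frame shift assumed here, 9 frames correspond to roughly
+        0.09 s, which is where the threshold below comes from.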
+        """
+        return cut.duration >= 0.09
+
+    @lru_cache()
+    def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
+        logging.info("About to get AMI train cuts")
+
+        def _remove_short_and_long_utt(c: Cut):
+            if c.duration < 0.1 or c.duration > 25.0:
+                return False
+
+            # In pruned RNN-T, we require that T >= S
+            # where T is the number of feature frames after subsampling
+            # and S is the number of tokens in the utterance
+
+            # In ./zipformer.py, the conv module uses the following expression
+            # for subsampling
+            T = ((c.num_frames - 7) // 2 + 1) // 2
+            tokens = c.supervisions[0].text
+            return T >= len(tokens)
+
+        cuts_train = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_all.jsonl.gz"
+        )
+
+        return cuts_train.filter(_remove_short_and_long_utt)
+
+    @lru_cache()
+    def eval_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting IHM eval cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def eval_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting SDM eval cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def eval_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_eval_gss.jsonl.gz").exists():
+            logging.info("No GSS dev cuts found")
+            return None
+        logging.info("About to get AliMeeting GSS-enhanced eval cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting IHM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting SDM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
+            logging.info("No GSS test cuts found")
+            return None
+        logging.info("About to get AliMeeting GSS-enhanced test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..37516affc
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..53381c1f4
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,698 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method greedy_search
+
+(2) modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+(3) fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import NgramLm
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=10,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An interger indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    lexicon: Lexicon,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals the
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding-method is fast_beam_search.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = model.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    lexicon: Lexicon,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding-method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            lexicon=lexicon,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AlimeetingAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest_LG",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    model.device = device
+
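+    # Note: k2.trivial_graph(vocab_size - 1) builds an FSA with self-loops for
+    # token ids 1..vocab_size-1, so it accepts any token sequence and imposes
+    # no language-model constraint on fast_beam_search.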
+    if "fast_beam_search" in params.decoding_method:
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    alimeeting = AlimeetingAsrDataModule(args)
+
+    eval_ihm_cuts = alimeeting.eval_ihm_cuts()
+    test_ihm_cuts = alimeeting.test_ihm_cuts()
+    eval_sdm_cuts = alimeeting.eval_sdm_cuts()
+    test_sdm_cuts = alimeeting.test_sdm_cuts()
+    eval_gss_cuts = alimeeting.eval_gss_cuts()
+    test_gss_cuts = alimeeting.test_gss_cuts()
+
+    eval_ihm_dl = alimeeting.test_dataloaders(eval_ihm_cuts)
+    test_ihm_dl = alimeeting.test_dataloaders(test_ihm_cuts)
+    eval_sdm_dl = alimeeting.test_dataloaders(eval_sdm_cuts)
+    test_sdm_dl = alimeeting.test_dataloaders(test_sdm_cuts)
+    if eval_gss_cuts is not None:
+        eval_gss_dl = alimeeting.test_dataloaders(eval_gss_cuts)
+    if test_gss_cuts is not None:
+        test_gss_dl = alimeeting.test_dataloaders(test_gss_cuts)
+
+    test_sets = {
+        "eval_ihm": (eval_ihm_dl, eval_ihm_cuts),
+        "test_ihm": (test_ihm_dl, test_ihm_cuts),
+        "eval_sdm": (eval_sdm_dl, eval_sdm_cuts),
+        "test_sdm": (test_sdm_dl, test_sdm_cuts),
+    }
+    if eval_gss_cuts is not None:
+        test_sets["eval_gss"] = (eval_gss_dl, eval_gss_cuts)
+    if test_gss_cuts is not None:
+        test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+    for test_set in test_sets:
+        logging.info(f"Decoding {test_set}")
+        dl, cuts = test_sets[test_set]
+        results_dict = decode_dataset(
+            dl=dl,
+            params=params,
+            model=model,
+            lexicon=lexicon,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
new file mode 100755
index 000000000..23a88dd29
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --lang-dir data/lang_char \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --lang-dir data/lang_char \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `pruned_transducer_stateless7/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/alimeeting/ASR_v2
+    ./pruned_transducer_stateless7/decode.py \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --lang-dir data/lang_char
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=15,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=8,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+
+    params.blank_id = 0
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptabe.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
new file mode 120000
index 000000000..a44034e34
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
new file mode 120000
index 000000000..068f0f57f
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..757d6535e
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1186 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 150 \
+  --use-fp16 True
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=15,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=5000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
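+    # Worked example (illustration only): with the default --average-period 200,
+    # the update above at batch_idx_train = 1000 is
+    #   model_avg = model * (200 / 1000) + model_avg * (800 / 1000),
+    # i.e. the new snapshot gets weight 0.2 and the running average keeps 0.8.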
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers in the transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 100,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
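+    # e.g. to_int_tuple("2,4,3,2,4") == (2, 4, 3, 2, 4); all comma-separated
+    # model options (--num-encoder-layers, --encoder-dims, ...) are parsed this
+    # way before being passed to the Zipformer constructor.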
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+
+    y = graph_compiler.texts_to_ids(texts)
+    if type(y) == list:
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = y.to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
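+        # ramp the scale on the pruned loss from 0.1 at the start up to 1.0 by warm_step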
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
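+        # count of frames after front-end subsampling, computed as (num_frames - 7) // 2 per utterance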
+        info["frames"] = ((feature_lens - 7) // 2).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
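+        # when resuming from a mid-epoch checkpoint, skip the batches that were already processed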
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
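+            # tot_loss is a decayed running total; it reflects roughly the last `reset_interval` batches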
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    alimeeting = AlimeetingAsrDataModule(args)
+
+    train_cuts = alimeeting.train_cuts()
+    train_dl = alimeeting.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = alimeeting.eval_ihm_cuts()
+    valid_dl = alimeeting.valid_dataloaders(valid_cuts)
+
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         graph_compiler=graph_compiler,
+    #         params=params,
+    #     )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      graph_compiler:
+        The compiler used to encode transcripts into token ids.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    AlimeetingAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/shared b/egs/alimeeting/ASR_v2/shared
new file mode 120000
index 000000000..3a3b28f96
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/shared
@@ -0,0 +1 @@
+../../../egs/aishell/ASR/shared
\ No newline at end of file

From 02c18ba4b25a805db8e8dbb6b7fc4766ad1e006a Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Sat, 10 Dec 2022 19:34:19 +0800
Subject: [PATCH 071/120] rm the dup line of Zipformer.py (#755)

Co-authored-by: yifanyang 
---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b007a7308..e8fd89abd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -81,7 +81,6 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
 
         self.num_features = num_features
-        self.encoder_unmasked_dims = encoder_unmasked_dims
         assert 0 < encoder_dims[0] <= encoder_dims[1]
         self.encoder_dims = encoder_dims
         self.encoder_unmasked_dims = encoder_unmasked_dims

From e83409cbe536cc031728f394fc3eb1132aac01e1 Mon Sep 17 00:00:00 2001
From: wzy <38179632+v-yunbin@users.noreply.github.com>
Date: Sun, 11 Dec 2022 20:16:10 +0800
Subject: [PATCH 072/120]  Filter the training data of T <  S  for Wenet train
 recipe (#753)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* filter the case of T <  S  for training data

* fix style issues

* fix style issues

* fix style issues

Co-authored-by: 张云斌 
---
 .../ASR/pruned_transducer_stateless2/train.py | 32 +++++++++++++++++--
 .../ASR/pruned_transducer_stateless5/train.py | 32 +++++++++++++++++--
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 43fa0d01b..48b347b64 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -861,15 +861,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
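+        # e.g. c.num_frames == 500 gives T == ((500 - 1) // 2 - 1) // 2 == 124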
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 440b65f32..34a72be8f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -1006,15 +1006,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 

From b25c234c51426d61552cdca819ab57fe712214c9 Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Sun, 11 Dec 2022 21:30:39 +0800
Subject: [PATCH 073/120] Add Zipformer-MMI (#746)

* Minor fix to conformer-mmi

* Minor fixes

* Fix decode.py

* add training files

* train with ctc warmup

* add pruned_transducer_stateless7_mmi

* add zipformer_mmi/mmi_decode.py, using HP as decoding graph

* add mmi_decode.py

* remove pruned_transducer_stateless7_mmi

* rename zipformer_mmi/train_with_ctc.py as zipformer_mmi/train.py

* remove unused method

* rename mmi_decode.py

* add export.py pretrained.py jit_pretrained.py ...

* add RESULTS.md

* add CI test

* add docs

* add README.md

Co-authored-by: pkufool 
---
 .flake8                                       |    3 +-
 ...n-librispeech-conformer-ctc3-2022-11-28.sh |    8 +-
 ...ed-transducer-stateless7-ctc-2022-12-01.sh |    8 +-
 ...un-librispeech-zipformer-mmi-2022-12-08.sh |  103 ++
 ...n-librispeech-2022-12-08-zipformer-mmi.yml |  167 +++
 docs/source/recipes/librispeech/index.rst     |    1 +
 .../recipes/librispeech/zipformer_mmi.rst     |  422 ++++++
 egs/librispeech/ASR/RESULTS.md                |   57 +
 .../ASR/conformer_ctc3/jit_pretrained.py      |    5 +-
 .../ASR/conformer_ctc3/pretrained.py          |    5 +-
 egs/librispeech/ASR/conformer_mmi/decode.py   |   12 +-
 .../ASR/conformer_mmi/train-with-attention.py |   76 +-
 egs/librispeech/ASR/conformer_mmi/train.py    |   67 +-
 egs/librispeech/ASR/generate-lm.sh            |    2 +-
 .../export.py                                 |    6 +-
 .../jit_pretrained_ctc.py                     |    5 +-
 .../pretrained_ctc.py                         |    5 +-
 egs/librispeech/ASR/zipformer_mmi/README.md   |   26 +
 egs/librispeech/ASR/zipformer_mmi/__init__.py |    0
 .../ASR/zipformer_mmi/asr_datamodule.py       |    1 +
 egs/librispeech/ASR/zipformer_mmi/decode.py   |  736 ++++++++++
 .../ASR/zipformer_mmi/encoder_interface.py    |    1 +
 egs/librispeech/ASR/zipformer_mmi/export.py   |  307 +++++
 .../ASR/zipformer_mmi/jit_pretrained.py       |  391 ++++++
 egs/librispeech/ASR/zipformer_mmi/model.py    |   75 ++
 egs/librispeech/ASR/zipformer_mmi/optim.py    |    1 +
 .../ASR/zipformer_mmi/pretrained.py           |  410 ++++++
 egs/librispeech/ASR/zipformer_mmi/scaling.py  |    1 +
 .../ASR/zipformer_mmi/scaling_converter.py    |    1 +
 .../ASR/zipformer_mmi/test_model.py           |   57 +
 egs/librispeech/ASR/zipformer_mmi/train.py    | 1198 +++++++++++++++++
 .../ASR/zipformer_mmi/zipformer.py            |    1 +
 icefall/decode.py                             |  101 ++
 icefall/mmi.py                                |   10 +-
 34 files changed, 4224 insertions(+), 45 deletions(-)
 create mode 100755 .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
 create mode 100644 .github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
 create mode 100644 docs/source/recipes/librispeech/zipformer_mmi.rst
 create mode 100644 egs/librispeech/ASR/zipformer_mmi/README.md
 create mode 100644 egs/librispeech/ASR/zipformer_mmi/__init__.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/decode.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/export.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
 create mode 100644 egs/librispeech/ASR/zipformer_mmi/model.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/optim.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/pretrained.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/scaling.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/test_model.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/train.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/zipformer.py

diff --git a/.flake8 b/.flake8
index a0f44263c..41d8799c8 100644
--- a/.flake8
+++ b/.flake8
@@ -1,7 +1,7 @@
 [flake8]
 show-source=true
 statistics=true
-max-line-length = 80
+max-line-length = 88
 per-file-ignores =
     # line too long
     icefall/diagnostics.py: E501,
@@ -12,6 +12,7 @@ per-file-ignores =
     egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
     egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
     egs/librispeech/ASR/conformer_ctc*/*py: E501,
+    egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
     egs/librispeech/ASR/RESULTS.md: E999,
 
     # invalid escape sequence (cause by tex formular), W605
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
index 27944807f..df29f188e 100755
--- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -13,7 +13,6 @@ cd egs/librispeech/ASR
 repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
 
 log "Downloading pre-trained model from $repo_url"
-git lfs install
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
@@ -23,7 +22,12 @@ soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
-git lfs pull --include "data/*"
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lm/G_4_gram.pt"
 git lfs pull --include "exp/jit_trace.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
index 6642d5f67..e081c9374 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -13,7 +13,6 @@ cd egs/librispeech/ASR
 repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
 
 log "Downloading pre-trained model from $repo_url"
-git lfs install
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
@@ -23,7 +22,12 @@ soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
-git lfs pull --include "data/*"
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lm/G_4_gram.pt"
 git lfs pull --include "exp/cpu_jit.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
new file mode 100755
index 000000000..77f28b054
--- /dev/null
+++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/3gram.pt"
+git lfs pull --include "data/lang_bpe_500/4gram.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./zipformer_mmi/export.py \
+  --exp-dir $repo/exp \
+  --use-averaged-model false \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./zipformer_mmi/jit_pretrained.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --nn-model-filename $repo/exp/cpu_jit.pt \
+  --lang-dir $repo/data/lang_bpe_500 \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+  log "$method"
+
+  ./zipformer_mmi/pretrained.py \
+    --method $method \
+    --checkpoint $repo/exp/pretrained.pt \
+    --lang-dir $repo/data/lang_bpe_500 \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode"  ]]; then
+  mkdir -p zipformer_mmi/exp
+  ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh zipformer_mmi/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+    log "Decoding with $method"
+
+    ./zipformer_mmi/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --use-averaged-model 0 \
+      --nbest-scale 1.2 \
+      --hp-scale 1.0 \
+      --max-duration $max_duration \
+      --lang-dir $repo/data/lang_bpe_500 \
+      --exp-dir zipformer_mmi/exp
+  done
+
+  rm zipformer_mmi/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
new file mode 100644
index 000000000..5472ca59b
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
@@ -0,0 +1,167 @@
+# Copyright      2022  Zengwei Yao
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-08-zipformer-mmi
+# zipformer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+concurrency:
+  group: run_librispeech_2022_12_08_zipformer-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_librispeech_2022_12_08_zipformer:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
+
+      - name: Display decoding results for librispeech zipformer-mmi
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./zipformer_mmi/exp
+
+          cd zipformer_mmi
+          echo "results for zipformer-mmi"
+          echo "===1best==="
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest==="
+          find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest-rescoring-LG==="
+          find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest-rescoring-3-gram==="
+          find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest-rescoring-4-gram==="
+          find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for librispeech zipformer-mmi
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
+          path: egs/librispeech/ASR/zipformer_mmi/exp/
diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/librispeech/index.rst
index 6c91b6750..568a8016f 100644
--- a/docs/source/recipes/librispeech/index.rst
+++ b/docs/source/recipes/librispeech/index.rst
@@ -7,3 +7,4 @@ LibriSpeech
    tdnn_lstm_ctc
    conformer_ctc
    lstm_pruned_stateless_transducer
+   zipformer_mmi
diff --git a/docs/source/recipes/librispeech/zipformer_mmi.rst b/docs/source/recipes/librispeech/zipformer_mmi.rst
new file mode 100644
index 000000000..db268dd02
--- /dev/null
+++ b/docs/source/recipes/librispeech/zipformer_mmi.rst
@@ -0,0 +1,422 @@
+Zipformer MMI
+===============
+
+.. hint::
+
+   Please scroll down to the bottom of this page to find download links
+   for pretrained models if you don't want to train a model from scratch.
+
+
+This tutorial shows you how to train a Zipformer MMI model
+with the `LibriSpeech `_ dataset.
+
+We use LF-MMI to compute the loss.
+
+.. note::
+
+   You can find the document about LF-MMI training at the following address:
+
+   ``_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+.. note::
+
+   We encourage you to read ``./prepare.sh``.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+  - ``--stage``
+  - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+  $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. hint::
+
+  If you have pre-downloaded the `LibriSpeech `_
+  dataset and the `musan `_ dataset, say,
+  they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+  the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+  ``./prepare.sh`` won't re-download them.
+
+.. note::
+
+  All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
+  are saved in ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+   To get the latest news of `next-gen Kaldi `_, please subscribe
+   the following YouTube channel by `Nadira Povey `_:
+
+      ``_
+
+..  youtube:: ofEIoJL-mGM
+
+Training
+--------
+
+For stability, it uses CTC loss for model warm-up and then switches to MMI loss.
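+
+Conceptually, the warm-up can be sketched as below. This is only an
+illustration; the actual schedule lives in ``./zipformer_mmi/train.py``, and
+``warmup_batches`` here is a made-up value:
+
+.. code-block:: python
+
+  import torch
+
+  def warmup_loss(
+      ctc_loss: torch.Tensor,
+      mmi_loss: torch.Tensor,
+      batch_idx: int,
+      warmup_batches: int = 2000,
+  ) -> torch.Tensor:
+      # Use the more stable CTC loss while the model warms up,
+      # then train with the MMI loss.
+      return ctc_loss if batch_idx < warmup_batches else mmi_loss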
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./zipformer_mmi/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+  - ``--full-libri``
+
+    If it's True, the training part uses all the training data, i.e.,
+    960 hours. Otherwise, the training part uses only the subset
+    ``train-clean-100``, which has 100 hours of training data.
+
+    .. CAUTION::
+
+      The training set is perturbed by speed with two factors: 0.9 and 1.1.
+      If ``--full-libri`` is True, each epoch actually processes
+      ``3x960 == 2880`` hours of data.
+
+  - ``--num-epochs``
+
+    It is the number of epochs to train. For instance,
+    ``./zipformer_mmi/train.py --num-epochs 30`` trains for 30 epochs
+    and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+    in the folder ``./zipformer_mmi/exp``.
+
+  - ``--start-epoch``
+
+    It's used to resume training.
+    ``./zipformer_mmi/train.py --start-epoch 10`` loads the
+    checkpoint ``./zipformer_mmi/exp/epoch-9.pt`` and starts
+    training from epoch 10, based on the state from epoch 9.
+
+  - ``--world-size``
+
+    It is used for multi-GPU single-machine DDP training.
+
+      - (a) If it is 1, then no DDP training is used.
+
+      - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+    The following shows some use cases with it.
+
+      **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+      GPU 2 for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/librispeech/ASR
+          $ export CUDA_VISIBLE_DEVICES="0,2"
+          $ ./zipformer_mmi/train.py --world-size 2
+
+      **Use case 2**: You have 4 GPUs and you want to use all of them
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/librispeech/ASR
+          $ ./zipformer_mmi/train.py --world-size 4
+
+      **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/librispeech/ASR
+          $ export CUDA_VISIBLE_DEVICES="3"
+          $ ./zipformer_mmi/train.py --world-size 1
+
+    .. caution::
+
+      Only multi-GPU single-machine DDP training is implemented at present.
+      Multi-GPU multi-machine DDP training will be added later.
+
+  - ``--max-duration``
+
+    It specifies the number of seconds over all utterances in a
+    batch, before **padding**.
+    If you encounter CUDA OOM, please reduce it.
+
+    .. HINT::
+
+      Due to padding, the number of seconds of all utterances in a
+      batch will usually be larger than ``--max-duration``.
+
+      A larger value for ``--max-duration`` may cause OOM during training,
+      while a smaller value may increase the training time. You have to
+      tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`zipformer_mmi/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./zipformer_mmi/train.py`` directly.
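+
+For orientation, the following is a minimal sketch of what such a function
+looks like. The keys shown are a small illustrative subset and the values are
+placeholders, not the recipe's actual defaults:
+
+.. code-block:: python
+
+  from icefall.utils import AttributeDict
+
+  def get_params() -> AttributeDict:
+      # Pre-configured options that are not exposed on the commandline.
+      return AttributeDict(
+          {
+              "log_interval": 50,  # print training stats every 50 batches
+              "valid_interval": 3000,  # run validation every 3000 batches
+              "best_train_loss": float("inf"),
+              "best_valid_loss": float("inf"),
+          }
+      )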
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``zipformer_mmi/exp``.
+You will find the following files in that directory:
+
+  - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+    These are checkpoint files saved at the end of each epoch, containing model
+    ``state_dict`` and optimizer ``state_dict``.
+    To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+      .. code-block:: bash
+
+        $ ./zipformer_mmi/train.py --start-epoch 11
+
+  - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+    These are checkpoint files saved every ``--save-every-n`` batches,
+    containing model ``state_dict`` and optimizer ``state_dict``.
+    To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+      .. code-block:: bash
+
+        $ ./zipformer_mmi/train.py --start-batch 436000
+
+  - ``tensorboard/``
+
+    This folder contains TensorBoard logs. Training loss, validation loss, learning
+    rate, etc, are recorded in these logs. You can visualize them by:
+
+      .. code-block:: bash
+
+        $ cd zipformer_mmi/exp/tensorboard
+        $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall"
+
+    It will print something like below:
+
+      .. code-block::
+
+        TensorFlow installation not found - running with reduced feature set.
+        Upload started and will continue reading any new data as it's added to the logdir.
+
+        To stop uploading, press Ctrl-C.
+
+        New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/
+
+    Note there is a URL in the above output. Click it and you will see
+    tensorboard.
+
+  .. hint::
+
+    If you don't have access to Google, you can use the following command
+    to view the tensorboard log locally:
+
+      .. code-block:: bash
+
+        cd zipformer_mmi/exp/tensorboard
+        tensorboard --logdir . --port 6008
+
+    It will print the following message:
+
+      .. code-block::
+
+        Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+        TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+    Now start your browser and go to ``http://localhost:6008/`` to view the tensorboard
+    logs.
+
+
+  - ``log/log-train-xxxx``
+
+    It is the detailed training log in text format, same as the one
+    you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 4 GPUs:
+
+.. code-block:: bash
+
+  export CUDA_VISIBLE_DEVICES="0,1,2,3"
+  ./zipformer_mmi/train.py \
+    --world-size 4 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --full-libri 1 \
+    --exp-dir zipformer_mmi/exp \
+    --max-duration 500 \
+    --use-fp16 1 \
+    --num-workers 2
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+   There are two kinds of checkpoints:
+
+    - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+      of each epoch. You can pass ``--epoch`` to
+      ``zipformer_mmi/decode.py`` to use them.
+
+    - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
+      every ``--save-every-n`` batches. You can pass ``--iter`` to
+      ``zipformer_mmi/decode.py`` to use them.
+
+    We suggest that you try both types of checkpoints and choose the one
+    that produces the lowest WERs.
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./zipformer_mmi/decode.py --help
+
+shows the options for decoding.
+
+The following shows the example using ``epoch-*.pt``:
+
+.. code-block:: bash
+
+  for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+    ./zipformer_mmi/decode.py \
+      --epoch 30 \
+      --avg 10 \
+      --exp-dir ./zipformer_mmi/exp/ \
+      --max-duration 100 \
+      --lang-dir data/lang_bpe_500 \
+      --nbest-scale 1.2 \
+      --hp-scale 1.0 \
+      --decoding-method $m
+  done
+
+
+Export models
+-------------
+
+`zipformer_mmi/export.py `_ supports exporting checkpoints from ``zipformer_mmi/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``zipformer_mmi/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+  ./zipformer_mmi/export.py \
+    --exp-dir ./zipformer_mmi/exp \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --epoch 30 \
+    --avg 9 \
+    --jit 0
+
+It will generate a file ``./zipformer_mmi/exp/pretrained.pt``.
+
+.. hint::
+
+   To use the generated ``pretrained.pt`` for ``zipformer_mmi/decode.py``,
+   you can run:
+
+   .. code-block:: bash
+
+      cd zipformer_mmi/exp
+      ln -s pretrained.pt epoch-9999.pt
+
+   And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
+   ``./zipformer_mmi/decode.py``.
+
+To use the exported model with ``./zipformer_mmi/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+  ./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  ./zipformer_mmi/export.py \
+    --exp-dir ./zipformer_mmi/exp \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --epoch 30 \
+    --avg 9 \
+    --jit 1
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
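+
+For example, here is a minimal loading sketch (the path is the file generated
+above):
+
+.. code-block:: python
+
+  import torch
+
+  # Load the torchscript model exported with --jit 1
+  model = torch.jit.load("zipformer_mmi/exp/cpu_jit.pt")
+  model.eval()
+
+  # Optionally move the parameters to a CUDA device
+  model = model.to("cuda")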
+
+To use the generated files with ``./zipformer_mmi/jit_pretrained.py``:
+
+.. code-block:: bash
+
+  ./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+  - ``_
+
+  See ``_
+  for the details of the above pretrained models
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 9e5669f6d..092f77814 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,62 @@
 ## Results
 
+### zipformer_mmi (zipformer with mmi loss)
+
+See  for more details.
+
+[zipformer_mmi](./zipformer_mmi)
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+Number of model parameters: 69136519, i.e., 69.14 M
+
+|                          | test-clean | test-other  | comment             |
+|--------------------------|------------|-------------|---------------------|
+| 1best                    | 2.54       | 5.65        | --epoch 30 --avg 10 |
+| nbest                    | 2.54       | 5.66        | --epoch 30 --avg 10 |
+| nbest-rescoring-LG       | 2.49       | 5.42        | --epoch 30 --avg 10 |
+| nbest-rescoring-3-gram   | 2.52       | 5.62        | --epoch 30 --avg 10 |
+| nbest-rescoring-4-gram   | 2.50       | 5.51        | --epoch 30 --avg 10 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --master-port 12345 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --lang-dir data/lang_bpe_500 \
+  --max-duration 500 \
+  --full-libri 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_mmi/exp
+```
+
+The decoding commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="5"
+
+for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+  ./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 10 \
+    --exp-dir ./zipformer_mmi/exp/ \
+    --max-duration 100 \
+    --lang-dir data/lang_bpe_500 \
+    --nbest-scale 1.2 \
+    --hp-scale 1.0 \
+    --decoding-method $m
+done
+```
+
+
 ### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss)
 
 See  for more details.
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index 5be898e37..76db46cc8 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -291,7 +291,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
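+        # Use the real (unpadded) number of frames of each utterance,
+        # i.e., its feature length divided by the subsampling factor,
+        # instead of the padded length nnet_output.shape[1].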
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
index 3628d6a5f..880945ea0 100755
--- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -339,7 +339,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
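+        # Use the real (unpadded) number of frames of each utterance,
+        # i.e., its feature length divided by the subsampling factor,
+        # instead of the padded length nnet_output.shape[1].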
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py
index e3c7b685f..74f6e73fa 100755
--- a/egs/librispeech/ASR/conformer_mmi/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi/decode.py
@@ -660,14 +660,22 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
     # CAUTION: `test_sets` is for displaying only.
     # If you want to skip test-clean, you have to skip
     # it inside the for loop. That is, use
     #
     #   if test_set == 'test-clean': continue
-    #
     test_sets = ["test-clean", "test-other"]
-    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
index f8c94cff9..100bc846a 100755
--- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
+++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
@@ -30,6 +30,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -100,6 +102,41 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_mmi/exp-attn",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
     return parser
 
 
@@ -114,12 +151,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -164,8 +195,6 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp_500_with_attention"),
-            "lang_dir": Path("data/lang_bpe_500"),
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -184,15 +213,12 @@ def get_params() -> AttributeDict:
             "beam_size": 6,  # will change it to 8 after some batches (see code)
             "reduction": "sum",
             "use_double_scores": True,
-            #  "att_rate": 0.0,
-            #  "num_decoder_layers": 0,
             "att_rate": 0.7,
             "num_decoder_layers": 6,
             # parameters for Noam
             "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
-            "use_pruned_intersect": False,
             "den_scale": 1.0,
             # use alignments before this number of batches
             "use_ali_until": 13000,
@@ -661,7 +687,7 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
 
-    fix_random_seed(42)
+    fix_random_seed(params.seed)
     if world_size > 1:
         setup_dist(rank, world_size, params.master_port)
 
@@ -745,8 +771,29 @@ def run(rank, world_size, args):
         valid_ali = None
 
     librispeech = LibriSpeechAsrDataModule(args)
-    train_dl = librispeech.train_dataloaders()
-    valid_dl = librispeech.valid_dataloaders()
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = librispeech.train_dataloaders(train_cuts)
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
@@ -796,6 +843,7 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
 
     world_size = args.world_size
     assert world_size >= 1
diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py
index 5cfb2bfc7..f9f80632e 100755
--- a/egs/librispeech/ASR/conformer_mmi/train.py
+++ b/egs/librispeech/ASR/conformer_mmi/train.py
@@ -30,6 +30,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -100,6 +102,26 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_mmi/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -107,6 +129,14 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
     return parser
 
 
@@ -121,12 +151,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -171,8 +195,6 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp_500"),
-            "lang_dir": Path("data/lang_bpe_500"),
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -193,13 +215,10 @@ def get_params() -> AttributeDict:
             "use_double_scores": True,
             "att_rate": 0.0,
             "num_decoder_layers": 0,
-            #  "att_rate": 0.7,
-            #  "num_decoder_layers": 6,
             # parameters for Noam
             "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
-            "use_pruned_intersect": False,
             "den_scale": 1.0,
             # use alignments before this number of batches
             "use_ali_until": 13000,
@@ -752,8 +771,29 @@ def run(rank, world_size, args):
         valid_ali = None
 
     librispeech = LibriSpeechAsrDataModule(args)
-    train_dl = librispeech.train_dataloaders()
-    valid_dl = librispeech.valid_dataloaders()
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = librispeech.train_dataloaders(train_cuts)
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     for epoch in range(params.start_epoch, params.num_epochs):
         fix_random_seed(params.seed + epoch)
@@ -804,6 +844,7 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
 
     world_size = args.world_size
     assert world_size >= 1
diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh
index 6baccd381..dacd276d1 100755
--- a/egs/librispeech/ASR/generate-lm.sh
+++ b/egs/librispeech/ASR/generate-lm.sh
@@ -2,7 +2,7 @@
 
 lang_dir=data/lang_bpe_500
 
-for ngram in 2 3 5; do
+for ngram in 2 3 4 5; do
   if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
     ./shared/make_kn_lm.py \
       -ngram-order ${ngram} \
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
index 59a393739..c1607699f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
@@ -72,14 +72,14 @@ Check ./pretrained.py for its usage.
 Note: If you don't want to train a model from scratch, we have
 provided one for you. You can get it at
 
-https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
 
 with the following commands:
 
     sudo apt-get install git-lfs
     git lfs install
-    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
-    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+    git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01/exp
 """
 
 import argparse
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
index d3343d34a..ad9cf08dc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
@@ -304,7 +304,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
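+        # Use the real (unpadded) number of frames of each utterance,
+        # i.e., its feature length divided by the subsampling factor,
+        # instead of the padded length nnet_output.shape[1].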
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
index 74aef1bc7..5d460edb5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -322,7 +322,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
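+        # Use the real (unpadded) number of frames of each utterance,
+        # i.e., its feature length divided by the subsampling factor,
+        # instead of the padded length nnet_output.shape[1].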
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/zipformer_mmi/README.md b/egs/librispeech/ASR/zipformer_mmi/README.md
new file mode 100644
index 000000000..8ca844180
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/README.md
@@ -0,0 +1,26 @@
+This recipe implements the Zipformer-MMI model.
+
+See https://k2-fsa.github.io/icefall/recipes/librispeech/zipformer_mmi.html for detailed tutorials.
+
+It uses **CTC loss for warm-up** and then switches to MMI loss during training.
+
+For decoding, it uses HP (H is ctc_topo, P is a token-level bi-gram LM) as the decoding graph. The supported decoding methods are listed below (an example decode command follows the list):
+- **1best**. Extract the best path from the decoding lattice as the decoding result.
+- **nbest**. Extract n paths from the decoding lattice; the path with the highest score is the decoding result.
+- **nbest-rescoring-LG**. Extract n paths from the decoding lattice and rescore them with a word-level 3-gram LM; the path with the highest score is the decoding result.
+- **nbest-rescoring-3-gram**. Extract n paths from the decoding lattice and rescore them with a token-level 3-gram LM; the path with the highest score is the decoding result.
+- **nbest-rescoring-4-gram**. Extract n paths from the decoding lattice and rescore them with a token-level 4-gram LM; the path with the highest score is the decoding result.
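+
+For example, a typical decoding run with the nbest-rescoring-LG method looks like the following (the epoch/avg values and paths are illustrative; see the tutorial linked above and egs/librispeech/ASR/RESULTS.md for the exact commands):
+
+```bash
+./zipformer_mmi/decode.py \
+  --epoch 30 \
+  --avg 10 \
+  --exp-dir ./zipformer_mmi/exp \
+  --max-duration 100 \
+  --lang-dir data/lang_bpe_500 \
+  --nbest-scale 1.2 \
+  --hp-scale 1.0 \
+  --decoding-method nbest-rescoring-LG
+```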
+
+Experimental results (WERs on test-clean & test-other) when training on train-clean-100 (epoch-30-avg-10):
+- 1best: 6.43 & 17.44
+- nbest, nbest-scale=1.2: 6.43 & 17.45
+- nbest-rescoring-LG, nbest-scale=1.2: 5.87 & 16.35
+- nbest-rescoring-3-gram, nbest-scale=1.2: 6.19 & 16.57
+- nbest-rescoring-4-gram, nbest-scale=1.2: 5.87 & 16.07
+
+Experimental results (WERs on test-clean & test-other) when training on full LibriSpeech (epoch-30-avg-10):
+- 1best: 2.54 & 5.65
+- nbest, nbest-scale=1.2: 2.54 & 5.66
+- nbest-rescoring-LG, nbest-scale=1.2: 2.49 & 5.42
+- nbest-rescoring-3-gram, nbest-scale=1.2: 2.52 & 5.62
+- nbest-rescoring-4-gram, nbest-scale=1.2: 2.50 & 5.51
diff --git a/egs/librispeech/ASR/zipformer_mmi/__init__.py b/egs/librispeech/ASR/zipformer_mmi/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
new file mode 100755
index 000000000..7d0ea78bb
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -0,0 +1,736 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Liyong Guo,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) 1best
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --decoding-method 1best
+(2) nbest
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest
+(3) nbest-rescoring-LG
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-LG
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-3-gram
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-4-gram
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.lexicon import Lexicon
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (3) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with a word-level 3-gram LM; the path with the
+          highest score is the decoding result.
+        - (4) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with a token-level 3-gram LM; the path with
+          the highest score is the decoding result.
+        - (5) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with a token-level 4-gram LM; the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HP: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    G: Optional[k2.Fsa] = None,
+    LG: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.decoding_method is "nbest-rescoring-LG", it uses nbest rescoring with a word-level 3-gram LM.
+        - If params.decoding_method is "nbest-rescoring-3-gram", it uses nbest rescoring with a token-level 3-gram LM.
+        - If params.decoding_method is "nbest-rescoring-4-gram", it uses nbest rescoring with a token-level 4-gram LM.
+
+      model:
+        The neural model.
+      HP:
+        The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
+      bpe_model:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      LG:
+        An LM. L is the lexicon, G is a word-level 3-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-LG".
+      G:
+        A token-level 3-gram or 4-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-3-gram"
+        or "nbest-rescoring-4-gram".
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    device = HP.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3, feature.shape
+    feature = feature.to(device)
+
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    method = params.decoding_method
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using HP, not HLG here.
+        #
+        # token_ids is a list-of-lists of token IDs
+        token_ids = get_texts(best_path)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        return {key: hyps}
+
+    assert method in [
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if method == "nbest-rescoring-LG":
+        assert LG is not None
+        LM = LG
+    else:
+        assert G is not None
+        LM = G
+    best_path_dict = nbest_rescore_with_LM(
+        lattice=lattice,
+        LM=LM,
+        num_paths=params.num_paths,
+        lm_scale_list=lm_scale_list,
+        nbest_scale=params.nbest_scale,
+    )
+
+    ans = dict()
+    suffix = f"-nbest-scale-{params.nbest_scale}-{params.num_paths}"
+    for lm_scale_str, best_path in best_path_dict.items():
+        token_ids = get_texts(best_path)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        ans[lm_scale_str + suffix] = hyps
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HP: k2.Fsa,
+    bpe_model: spm.SentencePieceProcessor,
+    G: Optional[k2.Fsa] = None,
+    LG: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HP:
+        The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
+      bpe_model:
+        The BPE model.
+      LG:
+        An LM. L is the lexicon, G is a word-level 3-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-LG".
+      G:
+        A token-level 3-gram or 4-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-3-gram"
+        or "nbest-rescoring-4-gram".
+
+    Returns:
+      Return a dict, whose key may be "no_rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HP=HP,
+            bpe_model=bpe_model,
+            batch=batch,
+            G=G,
+            LG=LG,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    ), params.decoding_method
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    LG = None
+    G = None
+
+    if params.decoding_method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+
+    elif params.decoding_method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = params.decoding_method[-6]
+        assert order in ("3", "4"), (params.decoding_method, order)
+        order = int(order)
+        if not (params.lang_dir / f"{order}gram.pt").is_file():
+            logging.info(f"Loading {order}gram.fst.txt")
+            logging.warning("It may take a few minutes.")
+            with open(params.lang_dir / f"{order}gram.fst.txt") as f:
+                first_token_disambig_id = lexicon.token_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_token_disambig_id] = 0
+                G = k2.Fsa.from_fsas([G]).to(device)
+                # G = k2.remove_epsilon(G)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
+        else:
+            logging.info(f"Loading pre-compiled {order}gram.pt")
+            d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        G.lm_scores = G.scores.clone()
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HP=HP,
+            bpe_model=bpe_model,
+            G=G,
+            LG=LG,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py b/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py
new file mode 100755
index 000000000..0af7bd367
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/export.py
@@ -0,0 +1,307 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `zipformer_mmi/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./zipformer_mmi/decode.py \
+        --exp-dir ./zipformer_mmi/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method 1best \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+    # You will find the pre-trained model in icefall-asr-librispeech-zipformer-mmi-2022-12-08/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
new file mode 100755
index 000000000..c9ef16ffa
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python3
+# Copyright      2021-2022  Xiaomi Corp.   (authors: Fangjun Kuang,
+#                                                    Zengwei)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+(1) 1best
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(2) nbest
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(3) nbest-rescoring-LG
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-LG \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-3-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-4-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params
+
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model cpu_jit.pt",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is a token-level bi-gram LM.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (3) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with a word-level 3-gram LM; the path with the
+          highest score is the decoding result.
+        - (4) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with a token-level 3-gram LM; the path with
+          the highest score is the decoding result.
+        - (5) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with a token-level 4-gram LM; the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.2,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""
+        Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(params.nn_model_filename)
+    model.eval()
+    model.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {args.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features,
+        batch_first=True,
+        padding_value=math.log(1e-10),
+    )
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
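+    # Ensure the decoding graph has an `lm_scores` attribute (used by the
+    # n-best rescoring code); fall back to a copy of `scores` if it is missing.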
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    method = params.method
+    assert method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    )
+    # loading language model for rescoring
+    LM = None
+    if method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+        LM = LG
+    elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = method[-6]
+        assert order in ("3", "4")
+        order = int(order)
+        logging.info(f"Loading pre-compiled {order}gram.pt")
+        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G.lm_scores = G.scores.clone()
+        LM = G
+
+    # Encoder forward
+    nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
+
+    batch_size = nnet_output.shape[0]
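+    # Each row of supervision_segments is
+    # (sequence_index, start_frame, num_frames), measured in frames
+    # after subsampling.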
+    supervision_segments = torch.tensor(
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+    else:
+        best_path_dict = nbest_rescore_with_LM(
+            lattice=lattice,
+            LM=LM,
+            num_paths=params.num_paths,
+            lm_scale_list=[params.ngram_lm_scale],
+            nbest_scale=params.nbest_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))
+
+    # Note: `best_path.aux_labels` contains token IDs, not word IDs
+    # since we are using HP, not HLG here.
+    #
+    # token_ids is a list-of-lists of token IDs
+    token_ids = get_texts(best_path)
+    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+    hyps = bpe_model.decode(token_ids)
+    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+    hyps = [s.split() for s in hyps]
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/model.py b/egs/librispeech/ASR/zipformer_mmi/model.py
new file mode 100644
index 000000000..4045c8b64
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/model.py
@@ -0,0 +1,75 @@
+# Copyright    2022  Xiaomi Corp.        (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+
+class CTCModel(nn.Module):
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        encoder_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder = encoder
+
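+        # Project the encoder output to the vocabulary size and apply
+        # log-softmax to obtain per-frame CTC log-probabilities.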
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(encoder_dim, vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+        Returns:
+          Return the ctc outputs and encoder output lengths.
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)
+
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
+        return ctc_output, x_lens
diff --git a/egs/librispeech/ASR/zipformer_mmi/optim.py b/egs/librispeech/ASR/zipformer_mmi/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
new file mode 100755
index 000000000..0e7fd0daf
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+# Copyright      2021-2022  Xiaomi Corp.   (authors: Fangjun Kuang,
+#                                                    Zengwei)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) 1best
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(2) nbest
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(3) nbest-rescoring-LG
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-LG \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-3-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-4-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+
+You can also use `./zipformer_mmi/exp/epoch-xx.pt`.
+
+Note: ./zipformer_mmi/exp/pretrained.pt is generated by
+./zipformer_mmi/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is a token-level bi-gram LM.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (3) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with a word-level 3-gram LM; the path with the
+          highest score is the decoding result.
+        - (4) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with a token-level 3-gram LM; the path with
+          the highest score is the decoding result.
+        - (5) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with a token-level 4-gram LM; the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.2,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""
+        Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
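+    # HP is the decoding graph: the CTC topology H composed with the
+    # token-level bi-gram LM P (see the --method help above).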
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    method = params.method
+    assert method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    )
+    # loading language model for rescoring
+    LM = None
+    if method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+        LM = LG
+    elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = method[-6]
+        assert order in ("3", "4")
+        order = int(order)
+        logging.info(f"Loading pre-compiled {order}gram.pt")
+        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G.lm_scores = G.scores.clone()
+        LM = G
+
+    # Encoder forward
+    nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+    else:
+        best_path_dict = nbest_rescore_with_LM(
+            lattice=lattice,
+            LM=LM,
+            num_paths=params.num_paths,
+            lm_scale_list=[params.ngram_lm_scale],
+            nbest_scale=params.nbest_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))
+
+    # Note: `best_path.aux_labels` contains token IDs, not word IDs
+    # since we are using HP, not HLG here.
+    #
+    # token_ids is a list-of-lists of token IDs
+    token_ids = get_texts(best_path)
+    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+    hyps = bpe_model.decode(token_ids)
+    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+    hyps = [s.split() for s in hyps]
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/scaling.py b/egs/librispeech/ASR/zipformer_mmi/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py b/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/test_model.py b/egs/librispeech/ASR/zipformer_mmi/test_model.py
new file mode 100755
index 000000000..7782845f4
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/test_model.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./zipformer_mmi/test_model.py
+"""
+
+import torch
+from train import get_ctc_model, get_params
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.num_encoder_layers = "2,4,3,2,4"
+    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
+    params.feedforward_dims = "1024,1024,2048,2048,1024"
+    params.nhead = "8,8,8,8,8"
+    params.encoder_dims = "384,384,384,384,384"
+    params.attention_dims = "192,192,192,192,192"
+    params.encoder_unmasked_dims = "256,256,256,256,256"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    model(x=features, x_lens=feature_lengths)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
new file mode 100755
index 000000000..b2784e47c
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -0,0 +1,1198 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir zipformer_mmi/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_mmi/exp \
+  --full-libri 1 \
+  --max-duration 500
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import CTCModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon, UniqLexicon
+from icefall.mmi import LFMMILoss
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: Number of batches to train with the CTC loss before
+              switching to the LF-MMI loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for mmi loss
+            "mmi_beam_size": 6,
+            "den_scale": 1.0,
+            # parameters for ctc loss
+            "ctc_beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_ctc_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+
+    model = CTCModel(
+        encoder=encoder,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute the CTC or LF-MMI loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      ctc_graph_compiler:
+        It is used to build a decoding graph from a ctc topo and the training
+        transcript during the CTC warm-up stage. The training transcript is
+        contained in the given `batch`, while the ctc topo is built when this
+        compiler is instantiated.
+      mmi_graph_compiler:
+        It is used to build the numerator and denominator graphs for the
+        LF-MMI loss after the warm-up stage.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `LFMMILoss.forward()`
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
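+    # allow_truncate lets a supervision segment be up to
+    # subsampling_factor - 1 frames longer than the network output;
+    # the extra frames are truncated.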
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    info = MetricsTracker()
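+    # Warm up with the CTC loss for the first `warm_step` batches,
+    # then switch to the LF-MMI loss.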
+    if batch_idx_train < warm_step:
+        # Training with ctc loss
+        # Works with a BPE model
+        token_ids = ctc_graph_compiler.texts_to_ids(texts)
+        decoding_graph = ctc_graph_compiler.compile(token_ids)
+        loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.ctc_beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+        info["ctc_loss"] = loss.detach().cpu().item()
+        info["mmi_loss"] = 0
+    else:
+        # Training with mmi loss
+        loss_fn = LFMMILoss(
+            graph_compiler=mmi_graph_compiler,
+            use_pruned_intersect=params.use_pruned_intersect,
+            den_scale=params.den_scale,
+            beam_size=params.mmi_beam_size,
+        )
+        loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
+        info["ctc_loss"] = 0
+        info["mmi_loss"] = loss.detach().cpu().item()
+
+    assert loss.requires_grad == is_training
+
+    info["frames"] = encoder_out_lens.sum().cpu().item()
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss, averaged over all frames, is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      ctc_graph_compiler:
+        It is used to convert transcripts to FSAs for the CTC loss.
+      mmi_graph_compiler:
+        It is used to convert transcripts to FSAs for the LF-MMI loss.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
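+    # When resuming from a checkpoint that was saved in the middle of an
+    # epoch, skip the batches that have already been processed.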
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    ctc_graph_compiler=ctc_graph_compiler,
+                    mmi_graph_compiler=mmi_graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
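+            # tot_loss decays by a factor of (1 - 1/reset_interval) each
+            # batch, so it reflects roughly the last `reset_interval` batches.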
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(
+                batch, params=params, graph_compiler=mmi_graph_compiler
+            )
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                ctc_graph_compiler=ctc_graph_compiler,
+                mmi_graph_compiler=mmi_graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    params.vocab_size = num_classes
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    assert "lang_bpe" in str(params.lang_dir)
+    ctc_graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="<UNK>",
+        sos_id=1,
+        eos_id=1,
+    )
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    # train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        # train_cuts += librispeech.train_clean_360_cuts()
+        # train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    graph_compiler: MmiTrainingGraphCompiler,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      graph_compiler:
+        The MMI training graph compiler used to convert transcripts to token IDs.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+    y = graph_compiler.texts_to_ids(supervisions["text"])
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    ctc_graph_compiler=ctc_graph_compiler,
+                    mmi_graph_compiler=mmi_graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(
+                batch, params=params, graph_compiler=mmi_graph_compiler
+            )
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
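The manual grad-scale handling inside `train_one_epoch` above is worth isolating. Below is a minimal sketch, assuming only that `scaler` is the `torch.cuda.amp.GradScaler` created with `init_scale=1.0` as in this file; it restates the doubling logic from the hunk above and is not itself part of the patch:

```
import logging
import torch

def maybe_grow_grad_scale(scaler: torch.cuda.amp.GradScaler, batch_idx: int) -> None:
    # PyTorch's built-in growth schedule (_growth_interval) cannot be made to
    # depend on the current scale, so the scale is doubled by hand while it is
    # still small, and training aborts if it has collapsed.
    cur_grad_scale = scaler._scale.item()  # private attribute, read as in the patch
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur_grad_scale * 2.0)  # GradScaler.update() accepts a new scale
    if cur_grad_scale < 0.01:
        logging.warning(f"Grad scale is small: {cur_grad_scale}")
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
```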
diff --git a/egs/librispeech/ASR/zipformer_mmi/zipformer.py b/egs/librispeech/ASR/zipformer_mmi/zipformer.py
new file mode 120000
index 000000000..79b076556
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/zipformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/icefall/decode.py b/icefall/decode.py
index e4c614c4e..68e490c5e 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -717,6 +717,107 @@ def rescore_with_n_best_list(
     return ans
 
 
+def nbest_rescore_with_LM(
+    lattice: k2.Fsa,
+    LM: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    nbest_scale: float = 1.0,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """Rescore an n-best list with an n-gram LM.
+    The path with the maximum score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc]. It must have the following
+        attributes: ``aux_labels`` and ``lm_scores``. They are both token
+        IDs.
+      LM:
+        An FsaVec containing only a single FSA. It is one of the following:
+        - LG, where L is the lexicon and G is a word-level n-gram LM.
+        - G, a token-level n-gram LM.
+      num_paths:
+        Size of nbest list.
+      lm_scale_list:
+        A list of floats representing LM score scales.
+      nbest_scale:
+        Scale to be applied to ``lattice.score`` when sampling paths
+        using ``k2.random_paths``.
+      use_double_scores:
+        True to use double precision during computation. False to use
+        single precision.
+    Returns:
+      A dict of FsaVec, whose key is an lm_scale and the value is the
+      best decoding path for each utterance in the lattice.
+    """
+    device = lattice.device
+
+    assert len(lattice.shape) == 3
+    assert hasattr(lattice, "aux_labels")
+    assert hasattr(lattice, "lm_scores")
+
+    assert LM.shape == (1, None, None)
+    assert LM.device == device
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    # nbest.fsa.scores contains 0s
+
+    nbest = nbest.intersect(lattice)
+
+    # Now nbest.fsa has its scores set
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    # am scores + bi-gram scores
+    hp_scores = nbest.tot_scores()
+
+    # Now start to intersect nbest with LG or G
+    inv_fsa = k2.invert(nbest.fsa)
+    if hasattr(LM, "aux_labels"):
+        # LM is LG here
+        # delete token IDs as it is not needed
+        del inv_fsa.aux_labels
+    inv_fsa.scores.zero_()
+    inv_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(inv_fsa)
+    path_to_utt_map = nbest.shape.row_ids(1)
+
+    LM = k2.arc_sort(LM)
+    path_lattice = k2.intersect_device(
+        LM,
+        inv_fsa_with_epsilon_loops,
+        b_to_a_map=torch.zeros_like(path_to_utt_map),
+        sorted_match_a=True,
+    )
+
+    # Its labels are token IDs.
+    # If LM is G, its aux_labels are token IDs;
+    # If LM is LG, its aux_labels are word IDs.
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+    one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
+
+    lm_scores = one_best.get_tot_scores(
+        use_double_scores=use_double_scores,
+        log_semiring=True,  # Note: we always use True
+    )
+    # If LM is LG, we might get empty paths
+    lm_scores[lm_scores == float("-inf")] = -1e9
+
+    ans = dict()
+    for lm_scale in lm_scale_list:
+        tot_scores = hp_scores.values / lm_scale + lm_scores
+        tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
+
+
 def rescore_with_whole_lattice(
     lattice: k2.Fsa,
     G_with_epsilon_loops: k2.Fsa,
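A hedged usage sketch for `nbest_rescore_with_LM` added above. The wrapper below only wires up the call; `G_path` pointing at a k2 FSA saved with `torch.save` (an LG or G produced by the recipe's preparation scripts) is an assumption, not something this patch defines:

```
import k2
import torch

def rescore_nbest_with_lm(lattice: k2.Fsa, G_path: str) -> dict:
    # Load a single FSA and wrap it into an FsaVec on the lattice's device,
    # which is the shape nbest_rescore_with_LM asserts on.
    G = k2.Fsa.from_dict(torch.load(G_path, map_location="cpu"))
    LM = k2.create_fsa_vec([G]).to(lattice.device)
    return nbest_rescore_with_LM(
        lattice=lattice,
        LM=LM,
        num_paths=100,
        lm_scale_list=[0.1 * i for i in range(1, 11)],
        nbest_scale=0.5,
    )
```

The returned dict maps keys such as `lm_scale_0.5` to the best-path FsaVec for that scale, as described in the docstring above.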
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 16ed6e032..b7777b434 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -112,8 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
+    )
 
     num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
@@ -144,7 +148,7 @@ def _compute_mmi_loss_pruned(
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
 
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
 
     # the values for search_beam/output_beam/min_active_states/max_active_states
     # are not tuned. You may want to tune them.

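For context on the `k2.intersect_dense` calls patched above: the numerator and denominator lattices feed the standard LF-MMI objective. The sketch below is only a paraphrase of that step (any denominator scaling used in the real `icefall/mmi.py` is omitted), mirroring `num_tot_scores` from the surrounding hunk:

```
import k2
import torch

def lfmmi_loss(num_lats: k2.Fsa, den_lats: k2.Fsa) -> torch.Tensor:
    # Total (log-semiring) scores of the numerator and denominator lattices,
    # one value per utterance, as in the hunk above.
    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    tot_scores = num_tot_scores - den_tot_scores  # per-utterance MMI objective
    return -1 * tot_scores.sum()  # negated so it can be minimized
```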
From 0470bbae66d2c9ebc91ee5d0dfa37dfb4df3a9cb Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Tue, 13 Dec 2022 15:47:30 +0800
Subject: [PATCH 074/120] minor fix for zipformer recipe (#758)

* minor fix

* add CI test
---
 .github/workflows/test.yml                    |  3 +++
 .../pruned_transducer_stateless7/export.py    |  1 -
 .../test_model.py                             | 20 +++++++++++++++----
 .../pruned_transducer_stateless7/zipformer.py | 16 +++++----------
 4 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4dbe99827..c062a2a3d 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -113,6 +113,9 @@ jobs:
           cd ../pruned_transducer_stateless4
           pytest -v -s
 
+          cd ../pruned_transducer_stateless7
+          pytest -v -s
+
           cd ../transducer_stateless
           pytest -v -s
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
index 9a6f3ed37..3e3160e7e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
@@ -294,7 +294,6 @@ def main():
 
     if params.jit is True:
         convert_scaled_to_non_scaled(model, inplace=True)
-        logging.info("Using torch.jit.script()")
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
index db7fb7b3e..cdf914df3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
@@ -20,19 +20,21 @@
 To run this file, do:
 
     cd icefall/egs/librispeech/ASR
-    python ./pruned_transducer_stateless4/test_model.py
+    python ./pruned_transducer_stateless7/test_model.py
 """
 
+import torch
+
+from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model
 
 
-def test_model_1():
+def test_model():
     params = get_params()
     params.vocab_size = 500
     params.blank_id = 0
     params.context_size = 2
     params.num_encoder_layers = "2,4,3,2,4"
-    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
     params.feedforward_dims = "1024,1024,2048,2048,1024"
     params.nhead = "8,8,8,8,8"
     params.encoder_dims = "384,384,384,384,384"
@@ -47,9 +49,19 @@ def test_model_1():
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
 
+    # Test jit script
+    convert_scaled_to_non_scaled(model, inplace=True)
+    # We won't use the forward() method of the model in C++, so just ignore
+    # it here.
+    # Otherwise, one of its arguments is a ragged tensor and is not
+    # torch scriptable.
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    print("Using torch.jit.script")
+    model = torch.jit.script(model)
+
 
 def main():
-    test_model_1()
+    test_model()
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e8fd89abd..ed1e2efa2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -454,7 +454,7 @@ class ZipformerEncoderLayer(nn.Module):
         # pooling module
         if torch.jit.is_scripting():
             src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
-        elif random.random() > dynamic_dropout:
+        elif random.random() >= dynamic_dropout:
             src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
 
         if torch.jit.is_scripting():
@@ -478,7 +478,7 @@ class ZipformerEncoderLayer(nn.Module):
                 src, src_key_padding_mask=src_key_padding_mask
             )
         else:
-            use_self_attn = random.random() > dynamic_dropout
+            use_self_attn = random.random() >= dynamic_dropout
             if use_self_attn:
                 src_att, attn_weights = self.self_attn(
                     src,
@@ -488,7 +488,7 @@ class ZipformerEncoderLayer(nn.Module):
                 )
                 src = src + src_att
 
-            if random.random() > dynamic_dropout:
+            if random.random() >= dynamic_dropout:
                 src = src + self.conv_module1(
                     src, src_key_padding_mask=src_key_padding_mask
                 )
@@ -497,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module):
             if use_self_attn:
                 src = src + self.self_attn.forward2(src, attn_weights)
 
-            if random.random() > dynamic_dropout:
+            if random.random() >= dynamic_dropout:
                 src = src + self.conv_module2(
                     src, src_key_padding_mask=src_key_padding_mask
                 )
@@ -1289,12 +1289,6 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )
 
-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            seq_len,
-            seq_len,
-        ]
-
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
                 attn_output_weights.masked_fill_(attn_mask, float("-inf"))

From b293db4baf1606cfe95066cf28ffde56173a7ddb Mon Sep 17 00:00:00 2001
From: Daniil 
Date: Tue, 13 Dec 2022 03:13:26 -0500
Subject: [PATCH 075/120] Tedlium3 conformer ctc2 (#696)

* modify preparation

* small refacor

* add tedlium3 conformer_ctc2

* modify decode

* filter unk in decode

* add scaling converter

* address comments

* fix lambda function lhotse

* add implicit manifest shuffle

* refactor ctc_greedy_search

* import model arguments from train.py

* style fix

* fix ci test and last style issues

* update RESULTS

* fix RESULTS numbers

* fix label smoothing loss

* update model parameters number in RESULTS
---
 .../ASR/conformer_ctc/label_smoothing.py      |    3 +-
 .../ASR/conformer_ctc2/subsampling.py         |    5 +-
 .../emformer2.py                              |    4 +-
 egs/librispeech/ASR/local/compile_hlg.py      |    2 +-
 .../ASR/local/compute_fbank_musan.py          |    8 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |   23 +-
 .../pruned_transducer_stateless2/scaling.py   |   20 +-
 .../scaling_converter.py                      |    2 +-
 egs/tedlium3/ASR/RESULTS.md                   |   83 ++
 egs/tedlium3/ASR/conformer_ctc2/__init__.py   |    0
 .../ASR/conformer_ctc2/asr_datamodule.py      |    1 +
 egs/tedlium3/ASR/conformer_ctc2/attention.py  |  201 +++
 egs/tedlium3/ASR/conformer_ctc2/combiner.py   |  244 ++++
 egs/tedlium3/ASR/conformer_ctc2/conformer.py  | 1033 ++++++++++++++++
 egs/tedlium3/ASR/conformer_ctc2/decode.py     |  899 ++++++++++++++
 egs/tedlium3/ASR/conformer_ctc2/export.py     |  294 +++++
 .../ASR/conformer_ctc2/label_smoothing.py     |    1 +
 egs/tedlium3/ASR/conformer_ctc2/lstmp.py      |    1 +
 egs/tedlium3/ASR/conformer_ctc2/optim.py      |    1 +
 egs/tedlium3/ASR/conformer_ctc2/scaling.py    |    1 +
 .../ASR/conformer_ctc2/scaling_converter.py   |    1 +
 .../ASR/conformer_ctc2/subsampling.py         |    1 +
 egs/tedlium3/ASR/conformer_ctc2/train.py      | 1061 ++++++++++++++++
 .../ASR/conformer_ctc2/transformer.py         | 1093 +++++++++++++++++
 .../convert_transcript_words_to_bpe_ids.py    |   42 +-
 .../convert_transcript_words_to_tokens.py     |    1 -
 .../ASR/local/generate_unique_lexicon.py      |    1 -
 egs/tedlium3/ASR/local/prepare_lang.py        |    1 -
 egs/tedlium3/ASR/local/prepare_lexicon.py     |   94 --
 egs/tedlium3/ASR/local/prepare_transcripts.py |   66 +-
 egs/tedlium3/ASR/local/prepare_words.py       |   83 ++
 egs/tedlium3/ASR/local/test_prepare_lang.py   |    1 -
 egs/tedlium3/ASR/prepare.sh                   |   98 +-
 icefall/decode.py                             |    2 -
 test/test_lexicon.py                          |    2 +-
 35 files changed, 5158 insertions(+), 215 deletions(-)
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/__init__.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/attention.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/combiner.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/conformer.py
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/decode.py
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/export.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/lstmp.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/optim.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/scaling.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/subsampling.py
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/train.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/transformer.py
 delete mode 120000 egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
 delete mode 120000 egs/tedlium3/ASR/local/generate_unique_lexicon.py
 delete mode 120000 egs/tedlium3/ASR/local/prepare_lang.py
 delete mode 100755 egs/tedlium3/ASR/local/prepare_lexicon.py
 create mode 100755 egs/tedlium3/ASR/local/prepare_words.py
 delete mode 120000 egs/tedlium3/ASR/local/test_prepare_lang.py

diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
index cb0d6e04d..52d2eda3b 100644
--- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
+++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
@@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
             mean of the output is taken. (3) "sum": the output will be summed.
         """
         super().__init__()
-        assert 0.0 <= label_smoothing < 1.0
+        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
+        assert reduction in ("none", "sum", "mean"), reduction
         self.ignore_index = ignore_index
         self.label_smoothing = label_smoothing
         self.reduction = reduction
diff --git a/egs/librispeech/ASR/conformer_ctc2/subsampling.py b/egs/librispeech/ASR/conformer_ctc2/subsampling.py
index 3fcb4196f..85a4dc8df 100644
--- a/egs/librispeech/ASR/conformer_ctc2/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc2/subsampling.py
@@ -24,10 +24,9 @@ from scaling import (
     ScaledConv2d,
     ScaledLinear,
 )
-from torch import nn
 
 
-class Conv2dSubsampling(nn.Module):
+class Conv2dSubsampling(torch.nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
 
     Convert an input of shape (N, T, idim) to an output
@@ -61,7 +60,7 @@ class Conv2dSubsampling(nn.Module):
         assert in_channels >= 7
         super().__init__()
 
-        self.conv = nn.Sequential(
+        self.conv = torch.nn.Sequential(
             ScaledConv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
index 65a7efa77..188059044 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
@@ -1435,7 +1435,7 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         states: List[torch.Tensor],
-    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Forward pass for streaming inference.
 
         B: batch size;
@@ -1640,7 +1640,7 @@ class Emformer(EncoderInterface):
         self,
         x: torch.Tensor,
         states: List[torch.Tensor],
-    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Forward pass for streaming inference.
 
         B: batch size;
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index df6c609bb..08dac6a7b 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
 
         Caution: We use a lexicon that contains disambiguation symbols
 
-    - G, the LM, built from data/lm/G_3_gram.fst.txt
+    - G, the LM, built from data/lm/G_n_gram.fst.txt
 
 The generated HLG is saved in $lang_dir/HLG.pt
 """
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index 4a4093ae4..62036467e 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -28,7 +28,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -41,6 +41,10 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
+def is_cut_long(c: MonoCut) -> bool:
+    return c.duration > 5
+
+
 def compute_fbank_musan():
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
@@ -86,7 +90,7 @@ def compute_fbank_musan():
                 recordings=combine(part["recordings"] for part in manifests.values())
             )
             .cut_into_windows(10.0)
-            .filter(lambda c: c.duration > 5)
+            .filter(is_cut_long)
             .compute_and_store_features(
                 extractor=extractor,
                 storage_path=f"{output_dir}/musan_feats",
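The hunk above swaps an inline lambda for the module-level `is_cut_long`. A plausible reason, not stated in the patch, is that lambdas cannot be pickled, which matters once the filtered CutSet is shipped to worker processes. A small self-contained check of that property:

```
import pickle

def is_long_enough(duration: float) -> bool:  # stands in for is_cut_long above
    return duration > 5

pickle.dumps(is_long_enough)  # module-level function: picklable

try:
    pickle.dumps(lambda duration: duration > 5)
except Exception as e:  # PicklingError (or AttributeError on some versions)
    print("lambda is not picklable:", e)
```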
diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py
index e121aefa9..2a2d9c219 100755
--- a/egs/librispeech/ASR/local/prepare_lang_bpe.py
+++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py
@@ -127,7 +127,7 @@ def lexicon_to_fst_no_sil(
 
 
 def generate_lexicon(
-    model_file: str, words: List[str]
+    model_file: str, words: List[str], oov: str
 ) -> Tuple[Lexicon, Dict[str, int]]:
     """Generate a lexicon from a BPE model.
 
@@ -136,6 +136,8 @@ def generate_lexicon(
         Path to a sentencepiece model.
       words:
         A list of strings representing words.
+      oov:
+        The out of vocabulary word in the lexicon.
     Returns:
       Return a tuple with two elements:
         - A dict whose keys are words and values are the corresponding
@@ -156,12 +158,9 @@ def generate_lexicon(
     for word, pieces in zip(words, words_pieces):
         lexicon.append((word, pieces))
 
-    # The OOV word is <UNK>
-    lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())]))
+    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
 
-    token2id: Dict[str, int] = dict()
-    for i in range(sp.vocab_size()):
-        token2id[sp.id_to_piece(i)] = i
+    token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
 
     return lexicon, token2id
 
@@ -176,6 +175,13 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--oov",
+        type=str,
+        default="<UNK>",
+        help="The out of vocabulary word in the lexicon.",
+    )
+
     parser.add_argument(
         "--debug",
         type=str2bool,
@@ -202,12 +208,13 @@ def main():
 
     words = word_sym_table.symbols
 
-    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
+    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
+
     for w in excluded:
         if w in words:
             words.remove(w)
 
-    lexicon, token_sym_table = generate_lexicon(model_file, words)
+    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
 
     lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index c802ecf89..963ebdc2d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if random.random() >= self.balance_prob:
             return x
-        else:
-            return ActivationBalancerFunction.apply(
-                x,
-                self.channel_dim,
-                self.min_positive,
-                self.max_positive,
-                self.max_factor / self.balance_prob,
-                self.min_abs,
-                self.max_abs,
-            )
+
+        return ActivationBalancerFunction.apply(
+            x,
+            self.channel_dim,
+            self.min_positive,
+            self.max_positive,
+            self.max_factor / self.balance_prob,
+            self.min_abs,
+            self.max_abs,
+        )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
index b712eeda0..a6540c584 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
@@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled(
     if not inplace:
         model = copy.deepcopy(model)
 
-    excluded_patterns = r"self_attn\.(in|out)_proj"
+    excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
     p = re.compile(excluded_patterns)
 
     d = {}
diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index 511b19f73..38eaa8f44 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -1,5 +1,88 @@
 ## Results
 
+### TedLium3 BPE training results (Conformer-CTC 2)
+
+#### [conformer_ctc2](./conformer_ctc2)
+
+See  for more details.
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model and decoding results at:
+
+
+Number of model parameters: 101141699, i.e., 101.14 M
+
+The WERs are
+
+|                          | dev        | test        | comment             |
+|--------------------------|------------|-------------|---------------------|
+| ctc decoding             | 6.45       | 5.96        | --epoch 38 --avg 26 |
+| 1best                    | 5.92       | 5.51        | --epoch 38 --avg 26 |
+| whole lattice rescoring  | 5.96       | 5.47        | --epoch 38 --avg 26 |
+| attention decoder        | 5.60       | 5.33        | --epoch 38 --avg 26 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train.py \
+    --world-size 4 \
+    --num-epochs 40 \
+    --exp-dir conformer_ctc2/exp \
+    --max-duration 350 \
+    --use-fp16 true
+```
+
+The decoding command is:
+```
+epoch=38
+avg=26
+
+## ctc decoding
+./conformer_ctc2/decode.py \
+  --method ctc-decoding \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## 1best
+./conformer_ctc2/decode.py \
+  --method 1best \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## whole lattice rescoring
+./conformer_ctc2/decode.py \
+  --method whole-lattice-rescoring \
+  --exp-dir conformer_ctc2/exp \
+  --lm-path data/lm/G_4_gram_big.pt \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## attention decoder
+./conformer_ctc2/decode.py \
+  --method attention-decoder \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+```
+
 ### TedLium3 BPE training results (Pruned Transducer)
 
 #### 2022-03-21
diff --git a/egs/tedlium3/ASR/conformer_ctc2/__init__.py b/egs/tedlium3/ASR/conformer_ctc2/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
new file mode 120000
index 000000000..49b2ee483
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
@@ -0,0 +1 @@
+../transducer_stateless/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/attention.py b/egs/tedlium3/ASR/conformer_ctc2/attention.py
new file mode 100644
index 000000000..178cd7e62
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/attention.py
@@ -0,0 +1,201 @@
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from scaling import ScaledLinear
+
+
+class MultiheadAttention(torch.nn.Module):
+    """Allows the model to jointly attend to information
+    from different representation subspaces. This is a modified
+    version of the original multihead attention
+    (see "Attention Is All You Need"),
+    in which the input / output projection layers are replaced
+    with the newly introduced ScaledLinear layer
+    (see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py).
+
+    Args:
+        embed_dim:
+          total dimension of the model.
+        num_heads:
+          number of parallel attention heads. Note that embed_dim will be split
+          across num_heads, i.e. each head will have dimension (embed_dim // num_heads).
+        dropout:
+          dropout probability on attn_output_weights. (default=0.0).
+        bias:
+          if specified, adds bias to input / output projection layers (default=True).
+        add_bias_kv:
+          if specified, adds bias to the key and value sequences at dim=0 (default=False).
+        add_zero_attn:
+          if specified, adds a new batch of zeros to the key and value sequences
+          at dim=1 (default=False).
+        batch_first:
+          if True, then the input and output tensors are provided as
+          (batch, seq, feature), otherwise (seq, batch, feature) (default=False).
+
+    Examples::
+        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
+        batch_first: bool = False,
+        device: Union[torch.device, str, None] = None,
+        dtype: Union[torch.dtype, str, None] = None,
+    ) -> None:
+
+        super().__init__()
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.batch_first = batch_first
+
+        if embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads. "
+                "Got embedding dim vs number 0f heads: "
+                f"{embed_dim} vs {num_heads}"
+            )
+
+        self.head_dim = embed_dim // num_heads
+
+        self.in_proj = ScaledLinear(
+            embed_dim,
+            3 * embed_dim,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+        self.out_proj = ScaledLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            initial_scale=0.25,
+            device=device,
+            dtype=dtype,
+        )
+
+        if add_bias_kv:
+            self.bias_k = torch.nn.Parameter(
+                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
+            )
+            self.bias_v = torch.nn.Parameter(
+                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
+            )
+        else:
+            self.register_parameter("bias_k", None)
+            self.register_parameter("bias_v", None)
+
+        self.add_zero_attn = add_zero_attn
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        if self.bias_k is not None:
+            torch.nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            torch.nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+            query:
+              Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q)
+              when batch_first=True, where L is the target sequence length, N is the batch size,
+              and E_q is the query embedding dimension embed_dim. Queries are compared against
+              key-value pairs to produce the output. See "Attention Is All You Need" for more details.
+            key:
+              Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when
+              batch_first=True, where S is the source sequence length, N is the batch size, and
+              E_k is the key embedding dimension kdim. See "Attention Is All You Need" for more details.
+            value:
+              Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when
+              batch_first=True, where S is the source sequence length, N is the batch size, and
+              E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details.
+            key_padding_mask:
+              If specified, a mask of shape (N, S) indicating which elements within key
+              to ignore for the purpose of attention (i.e. treat as "padding").
+              Binary and byte masks are supported. For a binary mask, a True value indicates
+              that the corresponding key value will be ignored for the purpose of attention.
+              For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
+            need_weights:
+              If specified, returns attn_output_weights in addition to attn_outputs (default=True).
+            attn_mask:
+              If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
+              (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length,
+              and S is the source sequence length. A 2D mask will be broadcasted across the batch while
+              a 3D mask allows for a different mask for each entry in the batch.
+              Binary, byte, and float masks are supported. For a binary mask, a True value indicates
+              that the corresponding position is not allowed to attend. For a byte mask, a non-zero
+              value indicates that the corresponding position is not allowed to attend. For a float mask,
+              the mask values will be added to the attention weight.
+
+        Returns:
+            attn_output:
+              Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True,
+              where L is the target sequence length, N is the batch size, and E is the embedding dimension
+              embed_dim.
+            attn_output_weights:
+              Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence
+              length, and S is the source sequence length. Only returned when need_weights=True.
+        """
+        if self.batch_first:
+            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+        (
+            attn_output,
+            attn_output_weights,
+        ) = torch.nn.functional.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            self.embed_dim,
+            self.num_heads,
+            in_proj_weight=self.in_proj.get_weight(),
+            in_proj_bias=self.in_proj.get_bias(),
+            bias_k=self.bias_k,
+            bias_v=self.bias_v,
+            add_zero_attn=self.add_zero_attn,
+            dropout_p=self.dropout,
+            out_proj_weight=self.out_proj.get_weight(),
+            out_proj_bias=self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+        if self.batch_first:
+            return attn_output.transpose(1, 0), attn_output_weights
+        return attn_output, attn_output_weights
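A quick shape check for the `MultiheadAttention` wrapper added above, assuming the file is importable as `attention` together with its `scaling` dependency; the concrete dimensions are arbitrary:

```
import torch

from attention import MultiheadAttention  # the module added above

embed_dim, num_heads = 256, 4
mha = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
mha.eval()

x = torch.randn(50, 8, embed_dim)  # (seq_len, batch, embed_dim); batch_first=False
out, weights = mha(x, x, x, need_weights=True)
assert out.shape == (50, 8, embed_dim)  # same layout as the query
assert weights.shape == (8, 50, 50)  # (batch, tgt_len, src_len), averaged over heads
```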
diff --git a/egs/tedlium3/ASR/conformer_ctc2/combiner.py b/egs/tedlium3/ASR/conformer_ctc2/combiner.py
new file mode 100644
index 000000000..ff526029d
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/combiner.py
@@ -0,0 +1,244 @@
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+
+class RandomCombine(torch.nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, in training time,
+    is a random combination of all the inputs; but which in test time
+    will be just the last input.
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    conformer layers.  This has a similar effect as iterated loss. (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(
+        self,
+        num_inputs: int,
+        final_weight: float = 0.5,
+        pure_prob: float = 0.5,
+        stddev: float = 2.0,
+    ) -> None:
+        """
+        Args:
+          num_inputs:
+            The number of tensor inputs, which equals the number of layers'
+            outputs that are fed into this module.  E.g. in an 18-layer neural
+            net if we output layers 16, 12, 18, num_inputs would be 3.
+          final_weight:
+            The amount of weight or probability we assign to the
+            final layer when randomly choosing layers or when choosing
+            continuous layer weights.
+          pure_prob:
+            The probability, on each frame, with which we choose
+            only a single layer to output (rather than an interpolation)
+          stddev:
+            A standard deviation that we add to log-probs for computing
+            randomized weights.
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows::
+            With probability `pure_prob`::
+               With probability `final_weight`: choose final layer,
+               Else: choose random non-final layer.
+            Else::
+               Choose initial log-weights that correspond to assigning
+               weight `final_weight` to the final layer and equal
+               weights to other layers; then add Gaussian noise
+               with variance `stddev` to these log-weights, and normalize
+               to weights (note: the average weight assigned to the
+               final layer here will not be `final_weight` if stddev>0).
+        """
+        super().__init__()
+        assert 0 <= pure_prob <= 1, pure_prob
+        assert 0 < final_weight < 1, final_weight
+        assert num_inputs >= 1, num_inputs
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = (
+            torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1))
+            .log()
+            .item()
+        )
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function.
+        Args:
+          inputs:
+            A list of Tensor, e.g. from various layers of a transformer.
+            All must be the same shape, of (*, num_channels)
+        Returns:
+          A Tensor of shape (*, num_channels). In test mode
+          this is just the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}"
+        if not self.training or torch.jit.is_scripting() or len(inputs) == 1:
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
+            (num_frames, num_channels, num_inputs)
+        )
+
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(
+            inputs[0].dtype, inputs[0].device, num_frames
+        )
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+
+        ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
+
+        return ans
+
+    def _get_random_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), such that
+          `ans.sum(dim=1)` is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(
+                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
+            )
+
+    def _get_random_pure_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
+          exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
+        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
+
+        indexes = torch.where(
+            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
+        )
+        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(
+            dtype=dtype
+        )
+        return ans
+
+    def _get_random_mixed_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random mixed weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), whose elements
+          are in [0..1] and sum to one over the second axis, i.e.
+          `ans.sum(dim=1)` is all ones.
+        """
+        logprobs = (
+            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
+            * self.stddev
+        )
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(
+    final_weight: float,
+    pure_prob: float,
+    stddev: float,
+) -> None:
+    print(
+        f"_test_random_combine: final_weight={final_weight}, "
+        f"pure_prob={pure_prob}, stddev={stddev}"
+    )
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(
+        num_inputs=num_inputs,
+        final_weight=final_weight,
+        pure_prob=pure_prob,
+        stddev=stddev,
+    )
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # .. since actually all ones.
+
+
+def _test_random_combine_main() -> None:
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
+
+if __name__ == "__main__":
+    _test_random_combine_main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/conformer.py b/egs/tedlium3/ASR/conformer_ctc2/conformer.py
new file mode 100644
index 000000000..fad2f371f
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/conformer.py
@@ -0,0 +1,1033 @@
+#!/usr/bin/env python3
+# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+#                2022  Xiaomi Corp.                              (author: Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+import warnings
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from combiner import RandomCombine
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from transformer import Supervisions, Transformer, encoder_padding_mask
+
+
+class Conformer(Transformer):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features (int):
+            number of input features.
+          num_classes (int):
+            number of output classes.
+          subsampling_factor (int):
+            subsampling factor of encoder;
+            currently, subsampling_factor MUST be 4.
+          d_model (int):
+            attention dimension, also the output dimension.
+          nhead (int):
+            number of heads in multi-head attention;
+            must satisfy d_model % nhead == 0.
+          dim_feedforward (int):
+            feedforward dimension.
+          num_encoder_layers (int):
+            number of encoder layers.
+          num_decoder_layers (int):
+            number of decoder layers.
+          dropout (float):
+            dropout rate.
+          layer_dropout (float):
+            layer-dropout rate.
+          cnn_module_kernel (int):
+            kernel size of convolution module.
+          aux_layer_period (int):
+            period for selecting auxiliary encoder-layer outputs to combine.
+        """
+
+        super().__init__(
+            num_features=num_features,
+            num_classes=num_classes,
+            subsampling_factor=subsampling_factor,
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+        )
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+        encoder_layer = ConformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+            cnn_module_kernel=cnn_module_kernel,
+        )
+
+        # aux_layers from 1/3
+        self.encoder = ConformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
+        )
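+        # For example (a sketch with the defaults num_encoder_layers=12 and
+        # aux_layer_period=3): aux_layers = [4, 7, 10]; ConformerEncoder then
+        # appends the last layer index 11, so the outputs of layers 4, 7, 10
+        # and 11 are combined by RandomCombine.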
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        supervisions: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            the input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling
+            It is read directly from the batch, without any sorting. It is used
+            to compute encoder padding mask, which is used as memory key padding
+            mask for the decoder.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up".  It is used
+            to turn modules on sequentially.
+
+        Returns:
+          torch.Tensor: Predictor tensor of dimension (S, N, C).
+          torch.Tensor: Mask tensor of dimension (N, S)
+        """
+        x = self.encoder_embed(x)
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+
+        x = self.encoder(
+            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+        )  # (S, N, C)
+
+        return x, mask
+
+
+class ConformerEncoderLayer(nn.Module):
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
+    See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
+
+    Examples:
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = encoder_layer(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+    ) -> None:
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+          cnn_module_kernel (int):
+            kernel size of convolution module (default=31).
+        """
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src:
+            the sequence to the encoder layer of shape (S, N, C) (required).
+          pos_emb:
+            positional embedding tensor of shape (N, 2*S-1, C) (required).
+          src_mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass layers more frequently.
+
+        Returns:
+            Output tensor of the shape (S, N, C), where
+            S is the source sequence length,
+            N is the batch size,
+            C is the feature number
+        """
+        src_orig = src
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
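+        # For example (a sketch with the defaults bypass_scale=0.1 and
+        # layer_dropout=0.075): at warmup=0.0, alpha is always 0.1, so the
+        # layer contributes only 10% of the output; once warmup >= 0.9,
+        # alpha is 1.0 with probability 0.925 and 0.1 with probability 0.075.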
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # multi-headed self-attention module
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src = src + self.dropout(self.conv_module(src))
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        return src
+
+
+class ConformerEncoder(nn.Module):
+    """
+    ConformerEncoder is a stack of N encoder layers
+
+    Examples:
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+
+        """
+        Args:
+          encoder_layer:
+            an instance of the ConformerEncoderLayer() class (required).
+          num_layers:
+            the number of sub-encoder-layers in the encoder (required).
+          aux_layers:
+            list of indexes of sub-encoder-layers outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layers in turn.
+
+        Args:
+          src:
+            the sequence to the encoder of shape (S, N, C) (required).
+          pos_emb:
+            positional embedding tensor of shape (N, 2*S-1, C) (required).
+          mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass layers more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = src
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                pos_emb,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """
+    Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """
+        Construct an PositionalEncoding object.
+
+        Args:
+          d_model: Embedding dimension.
+          dropout_rate: Dropout rate.
+          max_len: Maximum input length.
+
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """
+        Reset the positional encodings.
+
+        Args:
+          x:
+            input tensor (N, T, C), where
+            T is the source sequence length,
+            N is the batch size,
+            C is the feature number.
+
+        """
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` the position
+        # of the key vector. We use positive relative positions when keys are
+        # to the left (i > j) and negative relative positions otherwise (i < j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of the positive part and concatenate it with the
+        # negative part. This layout supports the shifting trick described in
+        # "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Add positional encoding.
+
+        Args:
+          x:
+            input tensor (N, T, C).
+
+        Returns:
+          torch.Tensor: Encoded tensor (N, T, C).
+          torch.Tensor: Positional embedding tensor (N, 2*T-1, C), where
+          T is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
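+        # A sketch of the intent (assuming the espnet-style extend_pe above):
+        # self.pe stores embeddings for relative offsets +(L-1) .. -(L-1) for
+        # some L >= T, with offset 0 at index self.pe.size(1) // 2; the slice
+        # picks the 2*T-1 entries covering offsets +(T-1) .. -(T-1).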
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    """
+    Multi-Head Attention layer with relative position encoding
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
+
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ) -> None:
+        """
+        Args:
+          embed_dim:
+            total dimension of the model.
+          num_heads:
+            parallel attention heads.
+          dropout:
+            a Dropout layer on attn_output_weights. Default: 0.0.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+        self._reset_parameters()
+
+    def _pos_bias_u(self):
+        return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+    def _reset_parameters(self) -> None:
+        nn.init.normal_(self.pos_bias_u, std=0.01)
+        nn.init.normal_(self.pos_bias_v, std=0.01)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_emb: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          query, key, value: map a query and a set of key-value pairs to an output.
+          pos_emb: Positional embedding tensor
+          key_padding_mask: if provided, specified padding elements in the key will
+                            be ignored by the attention. When given a binary mask
+                            and a value is True, the corresponding value on the attention
+                            layer will be ignored. When given a byte mask and a value is
+                            non-zero, the corresponding value on the attention layer will be ignored.
+          need_weights: output attn_output_weights.
+          attn_mask: 2D or 3D mask that prevents attention to certain positions.
+                     A 2D mask will be broadcasted for all the batches while a 3D
+                     mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+          - Inputs:
+          - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+            positions will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+          - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+          - Outputs:
+          - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+          - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+        return self.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            pos_emb,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj.get_weight(),
+            self.in_proj.get_bias(),
+            self.dropout,
+            self.out_proj.get_weight(),
+            self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Compute relative positional encoding.
+
+        Args:
+          x:
+            input tensor (batch, head, time1, 2*time1-1).
+            time1 means the length of query vector.
+
+        Returns:
+          torch.Tensor: tensor of shape (batch, head, time1, time2)
+          (note: time2 has the same value as time1, but it is for
+          the key, while time1 is for the query).
+        """
+        (batch_size, num_heads, time1, n) = x.shape
+        assert n == 2 * time1 - 1
+        # Note: TorchScript requires explicit arg for stride()
+        batch_stride = x.stride(0)
+        head_stride = x.stride(1)
+        time1_stride = x.stride(2)
+        n_stride = x.stride(3)
+        return x.as_strided(
+            (batch_size, num_heads, time1, time1),
+            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+            storage_offset=n_stride * (time1 - 1),
+        )
+
+    def multi_head_attention_forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_emb: torch.Tensor,
+        embed_dim_to_check: int,
+        num_heads: int,
+        in_proj_weight: torch.Tensor,
+        in_proj_bias: torch.Tensor,
+        dropout_p: float,
+        out_proj_weight: torch.Tensor,
+        out_proj_bias: torch.Tensor,
+        training: bool = True,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          query, key, value: map a query and a set of key-value pairs to an output.
+          pos_emb: Positional embedding tensor
+          embed_dim_to_check: total dimension of the model.
+          num_heads: parallel attention heads.
+          in_proj_weight, in_proj_bias: input projection weight and bias.
+          dropout_p: probability of an element to be zeroed.
+          out_proj_weight, out_proj_bias: the output projection weight and bias.
+          training: apply dropout if is ``True``.
+          key_padding_mask: if provided, specified padding elements in the key will
+                            be ignored by the attention. This is a binary mask.
+                            When the value is True, the corresponding value on the
+                            attention layer will be filled with -inf.
+          need_weights: output attn_output_weights.
+          attn_mask: 2D or 3D mask that prevents attention to certain positions.
+                     A 2D mask will be broadcasted for all the batches while a 3D
+                     mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+          Inputs:
+          - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+            length, N is the batch size, E is the embedding dimension.
+          - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+            will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+          - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+          Outputs:
+          - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+          - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == embed_dim_to_check
+        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+        head_dim = embed_dim // num_heads
+        assert (
+            head_dim * num_heads == embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        scaling = float(head_dim) ** -0.5
+
+        if torch.equal(query, key) and torch.equal(key, value):
+            # self-attention
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
+
+        elif torch.equal(key, value):
+            # encoder-decoder attention
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+        else:
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = embed_dim * 2
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            k = nn.functional.linear(key, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim * 2
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            v = nn.functional.linear(value, _w, _b)
+
+        if attn_mask is not None:
+            assert (
+                attn_mask.dtype == torch.float32
+                or attn_mask.dtype == torch.float64
+                or attn_mask.dtype == torch.float16
+                or attn_mask.dtype == torch.uint8
+                or attn_mask.dtype == torch.bool
+            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+                attn_mask.dtype
+            )
+            if attn_mask.dtype == torch.uint8:
+                warnings.warn(
+                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+                )
+                attn_mask = attn_mask.to(torch.bool)
+
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+            elif attn_mask.dim() == 3:
+                if list(attn_mask.size()) != [
+                    bsz * num_heads,
+                    query.size(0),
+                    key.size(0),
+                ]:
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+            else:
+                raise RuntimeError(
+                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
+                )
+            # attn_mask's dim is 3 now.
+
+        # convert ByteTensor key_padding_mask to bool
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+            warnings.warn(
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+                key_padding_mask.size(1), src_len
+            )
+
+        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+
+        pos_emb_bsz = pos_emb.size(0)
+        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+
+        # compute matrix b and matrix d
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p.transpose(-2, -1)
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+
+        assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        attn_output = (
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+        else:
+            return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+        """
+        ConvolutionModule in Conformer model.
+        Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+        Construct a ConvolutionModule object.
+
+        Args:
+          channels (int):
+            the number of channels of conv layers.
+          kernel_size (int):
+            kernel size of conv layers.
+          bias (bool):
+            whether to use bias in conv layers (default=True).
+        """
+        super().__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels.  This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively.  (for most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.)  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute convolution module.
+
+        Args:
+          x:
+            input tensor of shape (T, N, C).
+
+        Returns:
+          torch.Tensor: Output tensor (T, N, C), where
+          T is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py
new file mode 100755
index 000000000..ce4dcd142
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py
@@ -0,0 +1,899 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+#                                            Fangjun Kuang,
+#                                            Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import shutil
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import TedLiumAsrDataModule
+from conformer import Conformer
+from train import add_model_arguments
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    load_averaged_model,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
+            - (1) ctc-greedy-search. It only uses the CTC output and a sentence piece
+              model for decoding. It produces the same results as ctc-decoding.
+            - (2) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (3) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (6) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (7) nbest-oracle. Its WER is the lower bound that any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-path",
+        type=str,
+        default="data/lm/G_4_gram.pt",
+        help="""The n-gram LM dir for rescoring.
+        It should contain either lm_fname.pt or lm_fname.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--result-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="Directory to store results.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
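+# Example usage (a sketch; adjust the paths and options to your setup):
+#
+#   ./conformer_ctc2/decode.py \
+#     --epoch 30 \
+#     --avg 15 \
+#     --method ctc-decoding \
+#     --exp-dir conformer_ctc2/exp \
+#     --lang-dir data/lang_bpe_500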
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+    """
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "feature_dim": 80,
+            # parameters for decoding
+            "search_beam": 15,
+            "output_beam": 8,
+            "min_active_states": 10,
+            "max_active_states": 7000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    mask: torch.Tensor,
+) -> List[List[int]]:
+    """Apply CTC greedy search
+    Args:
+      ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
+      mask (torch.Tensor): (batch, max_len)
+    Returns:
+      best path result
+    """
+
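+    # A small worked example (a sketch, assuming blank id 0): if one row of
+    # max_index is [0, 3, 3, 0, 5], torch.unique_consecutive gives [0, 3, 0, 5]
+    # and dropping the zeros (blanks) yields the hypothesis [3, 5].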
+    _, max_index = ctc_probs.max(2)  # (B, maxlen)
+    max_index = max_index.masked_fill_(mask, 0)  # (B, maxlen)
+
+    ret_hyps = []
+    for hyp in max_index:
+        hyp = torch.unique_consecutive(hyp)
+        hyp = hyp[hyp > 0].tolist()
+        ret_hyps.append(hyp)
+    return ret_hyps
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.method is "1best", it uses 1best decoding without LM rescoring.
+        - params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            torch.div(
+                supervisions["start_frame"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a lit-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        unk = bpe_model.decode(bpe_model.unk_id()).strip()
+        hyps = [[w for w in s.split() if w != unk] for s in hyps]
+        key = "ctc-decoding"
+
+        return {key: hyps}
+
+    if params.method == "ctc-greedy-search":
+        hyps = ctc_greedy_search(nnet_output, memory_key_padding_mask)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(hyps)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        unk = bpe_model.decode(bpe_model.unk_id()).strip()
+        hyps = [[w for w in s.split() if w != unk] for s in hyps]
+        key = "ctc-greedy-search"
+
+        return {key: hyps}
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="",
+        )
+        hyps = get_texts(best_path)
+        hyps = [
+            [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps
+        ]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.method == "nbest":
+        best_path = nbest_decoding(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            use_double_scores=params.use_double_scores,
+            nbest_scale=params.nbest_scale,
+        )
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [
+            [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps
+        ]
+        return {key: hyps}
+
+    assert params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "1best":
+        best_path_dict = one_best_decoding(
+            lattice=lattice,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "attention-decoder":
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            nbest_scale=params.nbest_scale,
+        )
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [
+                [word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
+            ]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((cut_id, ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
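+            # hyps_dict can be None, e.g. when lattice rescoring fails for a
+            # batch; record empty hypotheses under the existing keys so the
+            # reference transcripts are still counted.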
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for cut_id, ref_text in zip(cut_ids, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+) -> None:
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
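+    # Sort the settings by WER (ascending) so the best one comes first.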
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main() -> None:
+    parser = get_parser()
+    TedLiumAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_path = Path(args.lm_path)
+    args.result_dir = Path(args.result_dir)
+
+    if args.result_dir.is_dir():
+        shutil.rmtree(args.result_dir)
+    args.result_dir.mkdir()
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    if params.method in ("ctc-decoding", "ctc-greedy-search"):
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
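+        # HLG.lm_scores is required by the rescoring methods below, which
+        # replace the LM contribution of HLG with scores from G.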
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in ("nbest-rescoring", "whole-lattice-rescoring"):
+        assert params.lm_path.suffix in (".pt", ".txt")
+
+        if params.lm_path.is_file() and params.lm_path.suffix == ".pt":
+            logging.info(f"Loading pre-compiled {params.lm_path.name}")
+            d = torch.load(params.lm_path, map_location=device)
+            G = k2.Fsa.from_dict(d)
+        elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt":
+            raise FileNotFoundError(f"No such language model file: '{params.lm_path}'")
+        else:
+            # We reach this branch only if the LM filename ends with '.pt' but the
+            # file doesn't exist, or if it ends with '.txt' and the file exists.
+            if (
+                not params.lm_path.is_file()
+                and params.lm_path.suffix == ".pt"
+                and not (
+                    params.lm_path.parent / f"{params.lm_path.stem}.fst.txt"
+                ).is_file()
+            ):
+                raise FileNotFoundError(
+                    f"No such language model file: '{params.lm_path}'\n"
+                    "'.fst.txt' representation of the language model was "
+                    "not found either."
+                )
+            else:
+                # Whether params.lm_path.name is lm_name.pt or lm_name.fst.txt,
+                # we load lm_name.fst.txt here.
+                params.lm_path = params.lm_path.parent / params.lm_path.name.replace(
+                    ".pt", ".fst.txt"
+                )
+                logging.info(f"Loading {params.lm_path.name}")
+                logging.warning("It may take 8 minutes.")
+                with open(params.lm_path) as f:
+                    first_word_disambig_id = lexicon.word_table["#0"]
+
+                    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                    # G.aux_labels is not needed in later computations, so
+                    # remove it here.
+                    del G.aux_labels
+                    # CAUTION: The following line is crucial.
+                    # Arcs entering the back-off state have label equal to #0.
+                    # We have to change it to 0 here.
+                    G.labels[G.labels >= first_word_disambig_id] = 0
+                    # See https://github.com/k2-fsa/k2/issues/874
+                    # for why we need to set G.properties to None
+                    G.__dict__["_properties"] = None
+                    G = k2.Fsa.from_fsas([G]).to(device)
+                    G = k2.arc_sort(G)
+                    # Save a dummy value so that it can be loaded in C++.
+                    # See https://github.com/pytorch/pytorch/issues/67902
+                    # for why we need to do this.
+                    G.dummy = 1
+
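+                    # Cache the compiled G as a .pt file so subsequent runs
+                    # can skip the slow FST text parsing above.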
+                    torch.save(
+                        G.as_dict(),
+                        params.lm_path.parent
+                        / params.lm_path.name.replace(".fst.txt", ".pt"),
+                    )
+
+        if params.method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
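+            # Note: params.avg + 1 checkpoints are listed because the
+            # averaged-model method only needs the two endpoints of the
+            # interval: the earliest checkpoint (excluded) and the latest one.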
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    tedlium = TedLiumAsrDataModule(args)
+
+    valid_cuts = tedlium.dev_cuts()
+    test_cuts = tedlium.test_cuts()
+
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+    test_dl = tedlium.test_dataloaders(test_cuts)
+
+    test_sets = ["dev", "test"]
+    test_dls = [valid_dl, test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+# Importing add_model_arguments from train.py already calls
+# torch.set_num_interop_threads(1). Calling it again here in decode.py
+# would raise an error, hence the guard below.
+if torch.get_num_interop_threads() != 1:
+    torch.set_num_interop_threads(1)
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py
new file mode 100755
index 000000000..009bea230
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/export.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Behavox LLC (Author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./conformer_ctc2/export.py \
+  --exp-dir ./conformer_ctc2/exp \
+  --epoch 20 \
+  --avg 10 \
+  --jit 0
+
+It will generate a file exp_dir/pretrained.pt. With the default --jit 1,
+a TorchScript model is saved to exp_dir/cpu_jit.pt instead.
+
+To use the generated file with `conformer_ctc2/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/tedlium3/ASR
+    ./conformer_ctc2/decode.py \
+        --exp-dir ./conformer_ctc2/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 100
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from conformer import Conformer
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=True,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+    """
+    # parameters for conformer
+    params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80})
+    return params
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info(params)
+
+    logging.info("About to create model")
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                "Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
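+        # Convert the Scaled* layers to their plain torch counterparts so
+        # that torch.jit.script can handle the model.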
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/lstmp.py b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/optim.py b/egs/tedlium3/ASR/conformer_ctc2/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling.py b/egs/tedlium3/ASR/conformer_ctc2/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/subsampling.py b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py
new file mode 120000
index 000000000..8c91f2336
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc2/subsampling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py
new file mode 100755
index 000000000..42e4c010a
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/train.py
@@ -0,0 +1,1061 @@
+#!/usr/bin/env python3
+# Copyright    2022  Behavox LLC.        (authors: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc2/exp \
+  --max-duration 300
+
+# For mixed precision training:
+
+./conformer_ctc2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conformer_ctc2/exp \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+from asr_datamodule import TedLiumAsrDataModule
+from conformer import Conformer
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser) -> None:
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 -  att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward module dimension of the conformer model.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer multiheadattention modules.",
+    )
+
+    parser.add_argument(
+        "--dim-model",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer model.",
+    )
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt" and "bpe.model"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="Number of epochs that affects how rapidly the learning rate decreases.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
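+    # For example, with the default --average-period 100, at
+    # batch_idx_train=1000 the update above weights the current model by
+    # 0.1 and the running average by 0.9.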
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - model_warm_step: The number of warm-up batches passed to the model
+                           (it does not affect the learning rate schedule).
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 10,
+            "reset_interval": 200,
+            "valid_interval": 1000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for ctc loss
+            "beam_size": 10,
+            "reduction": "none",
+            "use_double_scores": True,
+            # parameters for model warm-up
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: torch.nn.Module,
+    model_avg: torch.nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that is used for training.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    model_avg: Optional[torch.nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used for training.
+      scheduler:
+        The learning rate scheduler used for training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(
+            feature, supervisions, warmup=warmup
+        )
+
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+        token_ids = convert_texts_into_ids(texts, graph_compiler.sp)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+
+        if params.att_rate > 0.0:
+            with torch.set_grad_enabled(is_training):
+                mmodel = model.module if hasattr(model, "module") else model
+                # Note: We need to generate an unsorted version of token_ids
+                # `encode_supervisions()` called above sorts text, but
+                # encoder_memory and memory_mask are not sorted, so we
+                # use an unsorted version `supervisions["text"]` to regenerate
+                # the token_ids
+                #
+                # See https://github.com/k2-fsa/icefall/issues/97
+                # for more details
+                unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+                att_loss = mmodel.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=unsorted_token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                    warmup=warmup,
+                )
+        else:
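+            # A zero placeholder keeps the finite checks and .sum() below
+            # valid when the attention decoder is disabled (att_rate == 0).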
+            att_loss = torch.tensor([0])
+
+        ctc_loss_is_finite = torch.isfinite(ctc_loss)
+        att_loss_is_finite = torch.isfinite(att_loss)
+        if torch.any(~ctc_loss_is_finite) or torch.any(~att_loss_is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"ctc_loss: {ctc_loss}\n"
+                f"att_loss: {att_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            ctc_loss = ctc_loss[ctc_loss_is_finite]
+            att_loss = att_loss[att_loss_is_finite]
+
+            # If all of the ctc_loss entries or all of the att_loss entries
+            # are inf or nan, we stop the training process by raising an exception
+            if torch.all(~ctc_loss_is_finite) or torch.all(~att_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        ctc_loss = ctc_loss.sum()
+        att_loss = att_loss.sum()
+
+        if params.att_rate > 0.0:
+            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+        else:
+            loss = ctc_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    # info["frames"] is an approximate number for two reasons:
+    # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+    # (2) If some utterances in the batch lead to inf/nan loss, they
+    #     are filtered out.
+    info["frames"] = (
+        torch.div(feature_lens, params.subsampling_factor, rounding_mode="floor")
+        .sum()
+        .item()
+    )
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate > 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch in valid_dl:
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[torch.nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
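+            # Note: warmup ramps linearly with batch_idx_train; compute_loss
+            # treats values >= 1.0 (i.e. after model_warm_step batches) as
+            # fully warmed up.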
+            # summary stats
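+            # tot_loss decays older batches by (1 - 1/reset_interval), so it
+            # is roughly an aggregate over the last reset_interval batches.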
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    if "lang_bpe" not in str(params.lang_dir):
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' in its name): {params.lang_dir}"
+        )
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[torch.nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
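+    # Eve optimizer with the Eden scheduler: the learning rate decays as a
+    # function of both the number of batches and the number of epochs processed.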
+    optimizer = optim.Eve(model.parameters(), lr=params.initial_lr)
+    scheduler = optim.Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and checkpoints.get("optimizer") is not None:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if checkpoints and checkpoints.get("scheduler") is not None:
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    tedlium = TedLiumAsrDataModule(args)
+
+    train_cuts = tedlium.train_cuts()
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = tedlium.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+
+    if (
+        params.start_epoch <= 1
+        and params.start_batch <= 0
+        and not params.print_diagnostics
+    ):
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
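+        # re-seed each epoch so that shuffling differs between epochs but remains reproducible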
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+        train_dl.dataset.epoch = epoch - 1
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[torch.nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    TedLiumAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
new file mode 100644
index 000000000..9dbf32e48
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
@@ -0,0 +1,1093 @@
+# Copyright    2021  University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright    2022  Xiaomi Corp.                              (author: Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from attention import MultiheadAttention
+from combiner import RandomCombine
+from label_smoothing import LabelSmoothingLoss
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledEmbedding,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features:
+            the input dimension of the model.
+          num_classes:
+            the output dimension of the model.
+          subsampling_factor:
+            number of output frames is num_in_frames // subsampling_factor;
+            currently, subsampling_factor MUST be 4.
+          d_model:
+            attention dimension.
+          nhead:
+            number of heads in multi-head attention;
+            must satisfy d_model % nhead == 0.
+          dim_feedforward:
+            the output dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            number of encoder layers.
+          num_decoder_layers:
+            number of decoder layers.
+          dropout:
+            dropout in encoder/decoder.
+          layer_dropout:
+            layer-dropout rate.
+          aux_layer_period:
+            period used to select the auxiliary encoder layers whose outputs are combined.
+        """
+        super().__init__()
+
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_classes -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.encoder_pos = PositionalEncoding(d_model, dropout)
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+        )
+        # aux_layers from 1/3
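+        # e.g. with num_encoder_layers=12 and aux_layer_period=3 the aux layers are
+        # [4, 7, 10]; the final layer (index 11) is appended inside TransformerEncoder.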
+        self.encoder = TransformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
+        )
+
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True)
+        )
+
+        if num_decoder_layers > 0:
+            self.decoder_num_class = (
+                self.num_classes
+            )  # bpe model already has sos/eos symbol
+
+            self.decoder_embed = ScaledEmbedding(
+                num_embeddings=self.decoder_num_class, embedding_dim=d_model
+            )
+            self.decoder_pos = PositionalEncoding(d_model, dropout)
+
+            decoder_layer = TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            self.decoder = TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=num_decoder_layers,
+                aux_layers=[],
+            )
+
+            self.decoder_output_layer = ScaledLinear(
+                d_model, self.decoder_num_class, bias=True
+            )
+
+            self.decoder_criterion = LabelSmoothingLoss(reduction="none")
+        else:
+            self.decoder_criterion = None
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        supervision: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, S, C).
+          supervision:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            (CAUTION: It contains length information, i.e., start and number of
+             frames, before subsampling)
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          Return a tuple containing 3 tensors:
+            - CTC output for ctc decoding. Its shape is (N, S, C)
+            - Encoder output with shape (S, N, C). It can be used as key and
+              value for the decoder.
+            - Encoder output padding mask. It can be used as
+              memory_key_padding_mask for the decoder. Its shape is (N, S).
+              It is None if `supervision` is None.
+        """
+
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision, warmup
+        )
+
+        x = self.ctc_output(encoder_memory)
+        return x, encoder_memory, memory_key_padding_mask
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        supervisions: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Run the transformer encoder.
+
+        Args:
+          x:
+            The model input. Its shape is (N, S, C).
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling.
+            It is read directly from the batch, without any sorting. It is used
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          Return a tuple with two tensors:
+            - The encoder output, with shape (S, N, C)
+            - encoder padding mask, with shape (N, S).
+              The mask is None if `supervisions` is None.
+              It is used as memory key padding mask in the decoder.
+        """
+        x = self.encoder_embed(x)
+        x = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+        x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup)  # (S, N, C)
+
+        return x, mask
+
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          x:
+            the output tensor from the transformer encoder;
+            its shape is (S, N, C)
+
+        Returns:
+          Return a tensor that can be used for CTC decoding.
+          Its shape is (N, S, C)
+        """
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (S, N, C) -> (N, S, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, S, C)
+        return x
+
+    @torch.jit.export
+    def decoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder of shape (S, N, C)
+          memory_key_padding_mask:
+            The padding mask from the encoder of shape (N, S).
+          token_ids:
+            A list of lists of token IDs. Each sublist contains the IDs
+            for one utterance.
+            The IDs can be either phone IDs or word piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A scalar, the **sum** of label smoothing loss over utterances
+          in the batch without any normalization.
+        """
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
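+        # -1 marks padded target positions so the label smoothing loss can ignore them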
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device)
+        ys_out_pad = ys_out_pad.to(device)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            warmup=warmup,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+
+        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+        return decoder_loss
+
+    @torch.jit.export
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[torch.Tensor],
+        sos_id: int,
+        eos_id: int,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder of shape (S, N, C).
+          memory_key_padding_mask:
+            The padding mask from the encoder of shape (N, S).
+          token_ids:
+            A list of lists of token IDs (e.g., word piece IDs); elements may
+            also be 1-D tensors when called from torchscript.
+            Each sublist represents an utterance.
+          sos_id:
+            The token ID for SOS.
+          eos_id:
+            The token ID for EOS.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A 2-D tensor of shape (len(token_ids), max_token_length)
+          representing the cross entropy loss (i.e., negative log-likelihood).
+        """
+        # The common part between this function and decoder_forward could be
+        # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+        ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            warmup=warmup,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            pred_pad.view(-1, self.decoder_num_class),
+            ys_out_pad.view(-1),
+            ignore_index=-1,
+            reduction="none",
+        )
+
+        nll = nll.view(pred_pad.shape[0], -1)
+
+        return nll
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    Modified from torch.nn.TransformerEncoderLayer.
+
+    Example:
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = encoder_layer(src)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale, respectively (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+        """
+
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = MultiheadAttention(d_model, nhead)
+        # Implementation of Feedforward model
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src:
+            the sequence to the encoder layer of shape (S, N, C) (required).
+          src_mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional)
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        src_orig = src
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
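+        # early in training (warmup close to 0) warmup_scale stays close to
+        # bypass_scale, so the layer's contribution is heavily down-weighted;
+        # once warmup >= 1 - bypass_scale the layer runs at full strength.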
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+        src = src + self.dropout(src_att)
+
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1.0 - alpha) * src_orig
+
+        return src
+
+
+class TransformerDecoderLayer(nn.Module):
+    """Modified from torch.nn.TransformerDecoderLayer.
+
+    Example:
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = decoder_layer(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed, the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale, respectively (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+        """
+
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = MultiheadAttention(d_model, nhead)
+        self.src_attn = MultiheadAttention(d_model, nhead)
+
+        # Implementation of Feedforward model
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the inputs (and mask) through the decoder layer.
+
+        Args:
+          tgt:
+            the sequence to the decoder layer of shape (T, N, C) (required).
+          memory:
+            the sequence from the last layer of the encoder of shape (S, N, C) (required).
+          tgt_mask:
+            the mask for the tgt sequence of shape (T, T) (optional).
+          memory_mask:
+            the mask for the memory sequence of shape (T, S) (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch of shape (N, T) (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch of shape (N, S) (optional).
+          warmup: controls selective bypass of layers; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (T, N, C), where
+          S is the source sequence length,
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        tgt_orig = tgt
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this decoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        tgt_att = self.self_attn(
+            tgt,
+            tgt,
+            tgt,
+            attn_mask=tgt_mask,
+            key_padding_mask=tgt_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout(tgt_att)
+
+        src_att = self.src_attn(
+            tgt,
+            memory,
+            memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout(src_att)
+
+        tgt = tgt + self.dropout(self.feed_forward(tgt))
+
+        tgt = self.norm_final(self.balancer(tgt))
+
+        if alpha != 1.0:
+            tgt = alpha * tgt + (1.0 - alpha) * tgt_orig
+
+        return tgt
+
+
+class TransformerEncoder(nn.Module):
+    """TransformerEncoder is a stack of N encoder layers
+
+    Examples:
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = transformer_encoder(src)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+        """
+        Args:
+          encoder_layer:
+            an instance of the TransformerEncoderLayer() class (required).
+          num_layers:
+            the number of sub-encoder-layers in the encoder (required).
+          aux_layers:
+            list of indexes of sub-encoder-layers outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the input through the encoder layers in turn.
+
+        Args:
+          src:
+            the input to the encoder of shape (S, N, C) (required).
+          mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = src
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
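+        # combine the outputs collected from the auxiliary layers and the final layer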
+        output = self.combiner(outputs)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    """TransformerDecoder is a stack of N decoder layers
+
+    Examples:
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = transformer_decoder(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        decoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+        """
+        Args:
+          decoder_layer:
+            an instance of the TransformerDecoderLayer() class (required).
+          num_layers:
+            the number of decoder layers in the decoder (required).
+          aux_layers:
+            list of indexes of decoder layer outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(decoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the input (and mask) through the decoder layers in turn.
+
+        Args:
+          tgt:
+            the sequence to the decoder of shape (T, N, C) (required).
+          memory:
+            the sequence from the last layer of the encoder of shape (S, N, C) (required).
+          tgt_mask:
+            the mask for the tgt sequence of shape (T, T) (optional).
+          memory_mask:
+            the mask for the memory sequence of shape (T, S) (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch of shape (N, T)  (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (T, N, C), where
+          S is the source sequence length,
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = tgt
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i)   = sin(pos / (10000^(2i/d_model)))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note:
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model: Embedding dimension.
+          dropout: Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (1, T, d_model). Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape (N, T, C).
+            T is the target sequence length,
+            N is the batch size,
+            C is the feature number.
+        """
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding.
+
+        Args:
+          x: Input of shape (N, T, C).
+
+        Returns:
+          A tensor of the same shape (N, T, C), where
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
+        x = x + self.pe[:, : x.size(1), :]
+        return self.dropout(x)
+
+
+def encoder_padding_mask(
+    max_len: int, supervisions: Optional[Supervisions] = None
+) -> Optional[torch.Tensor]:
+    """Make mask tensor containing indexes of padded part.
+
+    TODO:
+      This function **assumes** that the model uses
+      a subsampling factor of 4. We should remove that
+      assumption later.
+
+    Args:
+      max_len:
+        Maximum length of input features.
+        CAUTION: It is the length after subsampling.
+      supervisions:
+        Supervision in lhotse format.
+        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+        (CAUTION: It contains length information, i.e., start and number of
+         frames, before subsampling)
+
+    Returns:
+      Mask tensor of dimension (batch_size, input_length),
+      True denotes the masked indices.
+    """
+    if supervisions is None:
+        return None
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"],
+            supervisions["num_frames"],
+        ),
+        1,
+    ).to(torch.int32)
+
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    for idx in range(supervision_segments.size(0)):
+        # Note: TorchScript doesn't allow to unpack tensors as tuples
+        sequence_idx = supervision_segments[idx, 0].item()
+        start_frame = supervision_segments[idx, 1].item()
+        num_frames = supervision_segments[idx, 2].item()
+        lengths[sequence_idx] = start_frame + num_frames
+
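+    # Conv2dSubsampling reduces an input of length T to ((T - 1) // 2 - 1) // 2
+    # frames (roughly T // 4), so apply the same formula to the lengths.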
+    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
+    bs = int(len(lengths))
+    seq_range = torch.arange(0, max_len, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+    # Note: TorchScript doesn't implement Tensor.new()
+    seq_length_expand = torch.tensor(
+        lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
+    ).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    return mask
+
+
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+    """Generate a length mask for input.
+
+    The masked positions are filled with True;
+    unmasked positions are filled with False.
+
+    Args:
+      ys_pad:
+        padded tensor of dimension (batch_size, input_length).
+      ignore_id:
+        the ignored number (the padding number) in ys_pad
+
+    Returns:
+        A bool tensor of the same shape as the input tensor.
+    """
+    ys_mask = ys_pad == ignore_id
+    return ys_mask
+
+
+def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+    """Generate a square mask for the sequence. The masked positions are
+    filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0.]])
+
+    Args:
+      sz: mask size
+
+    Returns:
+      A square mask tensor of dimension (sz, sz)
+    """
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (
+        mask.float()
+        .masked_fill(mask == 0, float("-inf"))
+        .masked_fill(mask == 1, float(0.0))
+    )
+    return mask
+
+
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-lists of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      sos_id:
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-lists, where each sublist starts
+      with SOS ID.
+    """
+    return [[sos_id] + utt for utt in token_ids]
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-lists of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      eos_id:
+        The ID of the EOS token.
+
+    Return:
+      Return a new list-of-lists, where each sublist ends
+      with EOS ID.
+    """
+    return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 9dbcc9d9e..19ba8d24b 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -4,16 +4,18 @@
 """
 Convert a transcript based on words to a list of BPE ids.
 
-For example, if we use 2 as the encoding id of <unk>:
+For example, if we use 2 as the encoding id of <unk>
+Note: it inserts a space token before each <unk>
 
 texts = ['this is a <unk> day']
-spm_ids = [[38, 33, 6, 2, 316]]
+spm_ids = [[38, 33, 6, 15, 2, 316]]
 
 texts = ['<unk> this is a sunny day']
-spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
+spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]]
 
 texts = ['<unk>']
-spm_ids = [[2]]
+spm_ids = [[15, 2]]
+
 """
 
 import argparse
@@ -38,29 +40,27 @@ def get_args():
 
 def convert_texts_into_ids(
     texts: List[str],
-    unk_id: int,
     sp: spm.SentencePieceProcessor,
 ) -> List[List[int]]:
     """
     Args:
       texts:
         A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
-      unk_id:
-        A number id for the token '<unk>'.
+      sp:
+        A sentencepiece BPE model.
     Returns:
       Return an integer list of bpe ids.
     """
     y = []
     for text in texts:
-        y_ids = []
         if "" in text:
-            text_segments = text.split("")
-            id_segments = sp.encode(text_segments, out_type=int)
+            id_segments = sp.encode(text.split(""), out_type=int)
+
+            y_ids = []
             for i in range(len(id_segments)):
-                if i != len(id_segments) - 1:
-                    y_ids.extend(id_segments[i] + [unk_id])
-                else:
-                    y_ids.extend(id_segments[i])
+                y_ids += id_segments[i]
+                if i < len(id_segments) - 1:
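+                    # re-insert each removed "<unk>": a word-boundary piece "▁" followed by the unk id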
+                    y_ids += [sp.piece_to_id("▁"), sp.unk_id()]
         else:
             y_ids = sp.encode(text, out_type=int)
         y.append(y_ids)
@@ -70,19 +70,13 @@ def convert_texts_into_ids(
 
 def main():
     args = get_args()
-    texts = args.texts
-    bpe_model = args.bpe_model
 
     sp = spm.SentencePieceProcessor()
-    sp.load(bpe_model)
-    unk_id = sp.piece_to_id("<unk>")
+    sp.load(args.bpe_model)
 
-    y = convert_texts_into_ids(
-        texts=texts,
-        unk_id=unk_id,
-        sp=sp,
-    )
-    logging.info(f"The input texts: {texts}")
+    y = convert_texts_into_ids(texts=args.texts, sp=sp)
+
+    logging.info(f"The input texts: {args.texts}")
     logging.info(f"The encoding ids: {y}")
 
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
deleted file mode 120000
index 2ce13fd69..000000000
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/generate_unique_lexicon.py b/egs/tedlium3/ASR/local/generate_unique_lexicon.py
deleted file mode 120000
index c0aea1403..000000000
--- a/egs/tedlium3/ASR/local/generate_unique_lexicon.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/prepare_lang.py b/egs/tedlium3/ASR/local/prepare_lang.py
deleted file mode 120000
index 747f2ab39..000000000
--- a/egs/tedlium3/ASR/local/prepare_lang.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
deleted file mode 100755
index b9160b6d4..000000000
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ /dev/null
@@ -1,94 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2022  Xiaomi Corp.        (authors: Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input supervisions json dir "data/manifests"
-consisting of supervisions_train.json and does the following:
-
-1. Generate lexicon_words.txt.
-
-"""
-import argparse
-import logging
-from pathlib import Path
-
-import lhotse
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--manifests-dir",
-        type=str,
-        help="""Input directory.
-        """,
-    )
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        help="""Output directory.
-        """,
-    )
-
-    return parser.parse_args()
-
-
-def prepare_lexicon(manifests_dir: str, lang_dir: str):
-    """
-    Args:
-      manifests_dir:
-        The manifests directory, e.g., data/manifests.
-      lang_dir:
-        The language directory, e.g., data/lang_phone.
-
-    Return:
-      The lexicon_words.txt file.
-    """
-    words = set()
-
-    lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
-    for s in sups:
-        # list the words units and filter the empty item
-        words_list = list(filter(None, s.text.split()))
-
-        for word in words_list:
-            if word not in words and word != "<unk>":
-                words.add(word)
-
-    with open(lexicon, "w") as f:
-        for word in sorted(words):
-            f.write(word + "  " + word)
-            f.write("\n")
-
-
-def main():
-    args = get_args()
-    manifests_dir = Path(args.manifests_dir)
-    lang_dir = Path(args.lang_dir)
-
-    logging.info("Generating lexicon_words.txt")
-    prepare_lexicon(manifests_dir, lang_dir)
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 7ea4e89a4..d4ccdd1e3 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright    2021  Xiaomi Corp.        (authors: Mingshuang Luo)
+# Copyright    2021  Xiaomi Corp.        (author: Mingshuang Luo)
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -17,68 +18,67 @@
 
 
 """
-This script takes as input supervisions json dir "data/manifests"
-consisting of supervisions_train.json and does the following:
-
-1. Generate train.text.
+This script takes an input text file and removes all words
+that include any character outside the English alphabet.
 
 """
 import argparse
 import logging
+import re
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--manifests-dir",
+        "--input-text-path",
         type=str,
-        help="""Input directory.
-        """,
+        help="Input text file path.",
     )
     parser.add_argument(
-        "--lang-dir",
+        "--output-text-path",
         type=str,
-        help="""Output directory.
-        """,
+        help="Output text file path.",
     )
 
     return parser.parse_args()
 
 
-def prepare_transcripts(manifests_dir: str, lang_dir: str):
+def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None:
     """
     Args:
-      manifests_dir:
-        The manifests directory, e.g., data/manifests.
-      lang_dir:
-        The language directory, e.g., data/lang_phone.
+      input_text_path:
+        The input data text file path, e.g., data/lang/train_orig.txt.
+      output_text_path:
+        The output data text file path, e.g., data/lang/train.txt.
 
     Return:
-      The train.text in lang_dir.
+      Saved text file in output_text_path.
     """
-    texts = []
 
-    train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
-    for s in sups:
-        texts.append(s.text)
+    foreign_chr_check = re.compile(r"[^a-z']")
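+    # any word containing a character outside [a-z'] will be dropped below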
 
-    with open(train_text, "w") as f:
-        for text in texts:
-            f.write(text)
-            f.write("\n")
+    logging.info(f"Loading {input_text_path.name}")
+    with open(input_text_path, "r", encoding="utf8") as f:
+        texts = {t.rstrip("\n") for t in f}
+
+    texts = {
+        " ".join([w for w in t.split() if foreign_chr_check.search(w) is None])
+        for t in texts
+    }
+
+    with open(output_text_path, "w+", encoding="utf8") as f:
+        for t in texts:
+            f.write(f"{t}\n")
 
 
-def main():
+def main() -> None:
     args = get_args()
-    manifests_dir = Path(args.manifests_dir)
-    lang_dir = Path(args.lang_dir)
+    input_text_path = Path(args.input_text_path)
+    output_text_path = Path(args.output_text_path)
 
-    logging.info("Generating train.text")
-    prepare_transcripts(manifests_dir, lang_dir)
+    logging.info(f"Generating {output_text_path.name}")
+    prepare_transcripts(input_text_path, output_text_path)
 
 
 if __name__ == "__main__":
diff --git a/egs/tedlium3/ASR/local/prepare_words.py b/egs/tedlium3/ASR/local/prepare_words.py
new file mode 100755
index 000000000..a37d0f08f
--- /dev/null
+++ b/egs/tedlium3/ASR/local/prepare_words.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright    2022  Behavox LLC.        (authors: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input the lang dir "data/lang"
+containing words_orig.txt and does the following:
+
+1. Generate words.txt.
+
+"""
+import argparse
+import logging
+import re
+from pathlib import Path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="Output directory.",
+    )
+
+    return parser.parse_args()
+
+
+def prepare_words(lang_dir: str) -> None:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang.
+
+    Return:
+      The words.txt file.
+    """
+
+    words_orig_path = Path(lang_dir) / "words_orig.txt"
+    words_path = Path(lang_dir) / "words.txt"
+
+    foreign_chr_check = re.compile(r"[^a-z']")
+
+    logging.info(f"Loading {words_orig_path.name}")
+    with open(words_orig_path, "r", encoding="utf8") as f:
+        words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")}
+    words = {w for w in words if foreign_chr_check.search(w) is None and w != ""}
+    words.add("")
+    words = ["", "!SIL"] + sorted(words) + ["#0", "", ""]
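+    # "<eps>" and "#0" are the epsilon and disambiguation symbols used by the
+    # lexicon/grammar FSTs; "<s>" and "</s>" are sentence boundary symbols.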
+
+    with open(words_path, "w+", encoding="utf8") as f:
+        for idx, word in enumerate(words):
+            f.write(f"{word} {idx}\n")
+
+
+def main() -> None:
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    logging.info("Generating words.txt")
+    prepare_words(lang_dir)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
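
For reference, the words.txt produced by prepare_words.py is a plain symbol table with one "word id" pair per line, epsilon at id 0 and the disambiguation/sentence-boundary symbols at the end. A minimal sketch of reading it back into a dict, assuming the standard icefall/k2 symbol layout shown above (the path is illustrative):

    from pathlib import Path

    # Illustrative path; prepare.sh writes words.txt under data/lang.
    word2id = {}
    with open(Path("data/lang") / "words.txt", encoding="utf8") as f:
        for line in f:
            word, idx = line.split()
            word2id[word] = int(idx)

    assert word2id["<eps>"] == 0  # epsilon is expected to be the first symbol
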
diff --git a/egs/tedlium3/ASR/local/test_prepare_lang.py b/egs/tedlium3/ASR/local/test_prepare_lang.py
deleted file mode 120000
index f0f864998..000000000
--- a/egs/tedlium3/ASR/local/test_prepare_lang.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/test_prepare_lang.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index 272cf7aed..3d90436ff 100755
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -5,7 +5,6 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 
 set -eou pipefail
 
-nj=15
 stage=0
 stop_stage=100
 
@@ -63,6 +62,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
   fi
 
+  # Download the big and small 4-gram language models
+  if [ ! -d $dl_dir/lm ]; then
+    wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm
+    wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm
+    gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz
+  fi
+
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
@@ -100,7 +106,14 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 
   if [ ! -e data/fbank/.tedlium3.done ]; then
     mkdir -p data/fbank
+
     python3 ./local/compute_fbank_tedlium.py
+
+    gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \
+    gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz
+    mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \
+       data/fbank/tedlium_cuts_train.jsonl.gz
+
     touch data/fbank/.tedlium3.done
   fi
 fi
@@ -115,28 +128,24 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Prepare phone based lang"
-  lang_dir=data/lang_phone
+  log "Stage 5: Prepare BPE train data and set of words"
+  lang_dir=data/lang
   mkdir -p $lang_dir
 
-  if [ ! -f $lang_dir/train.text ]; then
+  if [ ! -f $lang_dir/train.txt ]; then
+    gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt
+
     ./local/prepare_transcripts.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
+      --input-text-path $lang_dir/train_orig.txt \
+      --output-text-path $lang_dir/train.txt
   fi
 
-  if [ ! -f $lang_dir/lexicon_words.txt ]; then
-    ./local/prepare_lexicon.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
-  fi
+  if [ ! -f $lang_dir/words.txt ]; then
 
-  (echo '!SIL SIL'; echo '<unk> <unk>'; ) |
-    cat - $lang_dir/lexicon_words.txt |
-    sort | uniq > $lang_dir/lexicon.txt
+    awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic |
+    sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt
 
-  if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py --lang-dir $lang_dir
+    ./local/prepare_words.py --lang-dir $lang_dir
   fi
 fi
 
@@ -148,25 +157,56 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     mkdir -p $lang_dir
     # We reuse words.txt from phone based lexicon
     # so that the two can share G.pt later.
-    cp data/lang_phone/words.txt $lang_dir
-
-    if [ ! -f $lang_dir/transcript_words.txt ]; then
-      log "Generate data for BPE training"
-      cat data/lang_phone/train.text |
-      cut -d " " -f 2- > $lang_dir/transcript_words.txt
-      # remove the <unk> for transcript_words.txt
-      sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
-      sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
-      sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
-    fi
+    cp data/lang/words.txt $lang_dir
 
     ./local/train_bpe_model.py \
       --lang-dir $lang_dir \
       --vocab-size $vocab_size \
-      --transcript $lang_dir/transcript_words.txt
+      --transcript data/lang/train.txt
 
     if [ ! -f $lang_dir/L_disambig.pt ]; then
-      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov "<unk>"
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it with: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_4_gram_small.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Compile HLG"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/HLG.pt ]; then
+      ./local/compile_hlg.py \
+        --lang-dir $lang_dir \
+        --lm G_4_gram_small
     fi
   done
 fi
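
Stage 7 above writes data/lm/G_4_gram_small.fst.txt (and the big variant) with kaldilm, and Stage 8 serializes the decoding graph to $lang_dir/HLG.pt. A rough sketch of how such artifacts are typically loaded on the Python side with k2 — the paths and the vocab size 500 are illustrative, and this only mirrors what compile_hlg.py and the decoding scripts already do:

    import k2
    import torch

    # Text-format FST written by kaldilm in Stage 7; acceptor=False keeps the
    # separate input (word) and output labels.
    with open("data/lm/G_4_gram_small.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    # Decoding graph serialized by ./local/compile_hlg.py in Stage 8.
    HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu"))

    # Both objects can then be moved to the decoding device with .to(device).
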
diff --git a/icefall/decode.py b/icefall/decode.py
index 68e490c5e..23f9fb9b3 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -466,9 +466,7 @@ def one_best_decoding(
     Return:
       An FsaVec containing linear paths.
     """
-
     if lm_scale_list is not None:
-
         ans = dict()
         saved_am_scores = lattice.scores - lattice.lm_scores
         for lm_scale in lm_scale_list:
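
For readers without icefall/decode.py open, the branch being tidied up here rescales the lattice once per LM scale and takes the one-best path each time. A compressed, paraphrased sketch of that logic (simplified from memory, not the verbatim function body):

    from typing import Dict, List
    import k2

    def rescore_with_lm_scales(
        lattice: k2.Fsa, lm_scale_list: List[float], use_double_scores: bool = True
    ) -> Dict[str, k2.Fsa]:
        # Keep the acoustic part of the scores so every lm_scale starts from the same point.
        saved_am_scores = lattice.scores - lattice.lm_scores
        ans = {}
        for lm_scale in lm_scale_list:
            lattice.scores = saved_am_scores / lm_scale + lattice.lm_scores
            ans[f"lm_scale_{lm_scale}"] = k2.shortest_path(
                lattice, use_double_scores=use_double_scores
            )
        return ans
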
diff --git a/test/test_lexicon.py b/test/test_lexicon.py
index 69867efc7..b1beab3f6 100755
--- a/test/test_lexicon.py
+++ b/test/test_lexicon.py
@@ -112,7 +112,7 @@ def uniq_lexicon_test():
     # But there is no word "ca" in the lexicon, so our
     # implementation returns the id of "<unk>"
     print(token_ids, expected_token_ids)
-    assert token_ids.tolist() == [[sp.unk_id()]]
+    assert token_ids.tolist() == [[sp.piece_to_id("▁"), sp.unk_id()]]
 
     # case 3: With OOV
     texts = ["foo"]

From fbc88948044278b687a57309248eb2ae6df0a415 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Wed, 14 Dec 2022 13:47:23 +0800
Subject: [PATCH 076/120] Add comment for compile_hlg_using_openfst.py (#762)

---
 egs/librispeech/ASR/prepare.sh | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 11c8e1066..59bed8389 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -302,13 +302,20 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone
-  ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+
+  # Note: If ./local/compile_hlg.py throws OOM,
+  # please switch to the following command
+  #
+  # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
     ./local/compile_hlg.py --lang-dir $lang_dir
 
-    ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
+    # Note: If ./local/compile_hlg.py throws OOM,
+    # please switch to the following command
+    #
+    # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
   done
 fi
 

From ad475ec10dec864373099ba541cad5f743a4726b Mon Sep 17 00:00:00 2001
From: Wei Kang 
Date: Thu, 15 Dec 2022 19:07:28 +0800
Subject: [PATCH 077/120] Add documents for pruned_transducer_stateless (#526)

* begin to add documents for pruned_transducer_stateless

* Move lstm docs to Streaming folder

* Add documents for pruned transducer stateless models

* Move zipformer mmi to non-streaming recipe

* Add more docs for streaming decoding

* Fix typo
---
 docs/source/index.rst                         |   8 +
 .../aishell/conformer_ctc.rst                 |   0
 .../aishell-conformer-ctc-tensorboard-log.jpg | Bin
 .../aishell-tdnn-lstm-ctc-tensorboard-log.jpg | Bin
 ...cer_stateless_modified-tensorboard-log.png | Bin
 .../{ => Non-streaming-ASR}/aishell/index.rst |   0
 .../aishell/stateless_transducer.rst          |   0
 .../aishell/tdnn_lstm_ctc.rst                 |   0
 .../recipes/Non-streaming-ASR/index.rst       |  10 +
 .../librispeech/conformer_ctc.rst             |   0
 ...rispeech-conformer-ctc-tensorboard-log.png | Bin
 ...eech-pruned-transducer-tensorboard-log.jpg | Bin 0 -> 566971 bytes
 .../librispeech/index.rst                     |   1 +
 .../pruned_transducer_stateless.rst           | 545 +++++++++++++
 .../librispeech/tdnn_lstm_ctc.rst             |   0
 .../librispeech/zipformer_mmi.rst             |   0
 .../{ => Non-streaming-ASR}/timit/index.rst   |   0
 .../timit/tdnn_ligru_ctc.rst                  |   0
 .../timit/tdnn_lstm_ctc.rst                   |   0
 .../yesno/images/tdnn-tensorboard-log.png     | Bin
 .../{ => Non-streaming-ASR}/yesno/index.rst   |   0
 .../{ => Non-streaming-ASR}/yesno/tdnn.rst    |   0
 docs/source/recipes/Streaming-ASR/index.rst   |  12 +
 .../recipes/Streaming-ASR/introduction.rst    |  52 ++
 ...speech-lstm-transducer-tensorboard-log.png | Bin
 ...eech-pruned-transducer-tensorboard-log.jpg | Bin 0 -> 560358 bytes
 .../Streaming-ASR/librispeech/index.rst       |   9 +
 .../lstm_pruned_stateless_transducer.rst      |   0
 .../pruned_transducer_stateless.rst           | 735 ++++++++++++++++++
 docs/source/recipes/index.rst                 |   6 +-
 30 files changed, 1374 insertions(+), 4 deletions(-)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/conformer_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/index.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/stateless_transducer.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/tdnn_lstm_ctc.rst (100%)
 create mode 100644 docs/source/recipes/Non-streaming-ASR/index.rst
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/conformer_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png (100%)
 create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/index.rst (82%)
 create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/tdnn_lstm_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/zipformer_mmi.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/timit/index.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/timit/tdnn_ligru_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/timit/tdnn_lstm_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/yesno/images/tdnn-tensorboard-log.png (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/yesno/index.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/yesno/tdnn.rst (100%)
 create mode 100644 docs/source/recipes/Streaming-ASR/index.rst
 create mode 100644 docs/source/recipes/Streaming-ASR/introduction.rst
 rename docs/source/recipes/{ => Streaming-ASR}/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png (100%)
 create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg
 create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/index.rst
 rename docs/source/recipes/{ => Streaming-ASR}/librispeech/lstm_pruned_stateless_transducer.rst (100%)
 create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index be9977ca9..4ea446259 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -22,6 +22,14 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
 
    installation/index
    model-export/index
+
+.. toctree::
+   :maxdepth: 3
+
    recipes/index
+
+.. toctree::
+   :maxdepth: 2
+
    contributing/index
    huggingface/index
diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
similarity index 100%
rename from docs/source/recipes/aishell/conformer_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
diff --git a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
diff --git a/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
diff --git a/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst
similarity index 100%
rename from docs/source/recipes/aishell/index.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/index.rst
diff --git a/docs/source/recipes/aishell/stateless_transducer.rst b/docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst
similarity index 100%
rename from docs/source/recipes/aishell/stateless_transducer.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst
diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst
similarity index 100%
rename from docs/source/recipes/aishell/tdnn_lstm_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst
diff --git a/docs/source/recipes/Non-streaming-ASR/index.rst b/docs/source/recipes/Non-streaming-ASR/index.rst
new file mode 100644
index 000000000..67123a648
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/index.rst
@@ -0,0 +1,10 @@
+Non Streaming ASR
+=================
+
+.. toctree::
+   :maxdepth: 2
+
+   aishell/index
+   librispeech/index
+   timit/index
+   yesno/index
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
similarity index 100%
rename from docs/source/recipes/librispeech/conformer_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
diff --git a/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..800835749a7c7b2ca62398ba2e75d3091d06a6a7
GIT binary patch
literal 566971
[... 566971 bytes of base85-encoded JPEG data omitted ...]
z_-Nn`F28C}MM(~GS>ZkgrSNcy%)cfQ+0hvjcKSQ4~ zR+RQlBiC4=1OUW%=Q7#S=sD=_9C-LY&H=K{roHLkP!AEHQeC!8W;jfUuG2pQxQ6*g zvmij_`(LfezgNy|I!U)77QH(NxV4d{dvM}%Pg!%1M$ zD4J>)A&H6CM#!(A=sP1HeT@|i+VndcSC{tE85D%TzAVG&q4t%$@XB8QU>=;_yC%7u zw*zdXwd61J!i_)Hrzvjq(Rc-3Ap3@(!h=5|Y^MnQ@-uM1{@yvsoBb+?_ciS)$?aMq zRjLaZ6azsH+pz0d@Z%4NZr^%ifDGQigvt|@pnQWVOtR_fqr%xCN)qzo%W$i!J+aGH zVAAE1cD+pFDQ&lXN%3K(q~yyFZb`AAr0#Ft0`2ufR7r!#1FIHkz~UNxs-*%7ZUZoz zWLVUW5l*XjiP;w#KLTGh%QDsno`ulfQ97z-*gB4!* z5}imedxsC+*deDgkNm|}t$m4~)yztVlfurwG8DOtI`(uvMnWn#hqK&Pe)eQ%RoUtKTvGThtm}HjVQ53IF3;=n6vTq+db1jj}uw{OFdiQxxS7?;5;p&=(k>!&$ z^4}MFd#o8JWj*tqv)2Wyo;9$uN!8wD>t`yfaWi;8)}!Z*s*!I)MI+gfPL;C&3MT!D?@vAQ?&Gj& zz0ZDfsy@g57A0wncl!Aazc+`HnSB;Y00o!h^?UhfaeH~Q4WYl!z>G73&zygqz5iHf z(&JZi_ba*Wrw!CfB^O?)t?BucvfIb+_~5SkQPl!()^ZJW6QA)z{RVTw5+1p{&p==V z8uzgaaCS^x{%DKR%J3y({ic3I>a&`6iBWfN(S|eeGvs`yBabV>mb7zG?b@C_Bv{~v zd$=D}{y;JKvS5f{oyNf#jdh*tUiU^FUD*!&Q`*ApgpxC!lbu>&keqmEFIT{13;RdY zpGW+=(%&t+eHouQQm#0FR_k3Z`(aw?)ON=D&oMf~Q46+@MAqL)S5Weo9lP8m8AHiGNCnO#T zx~RdA7-|NuMt4-w66}G`7l#jQSCwVaE9l3F?nzV>WMJ5JDW(~!`fwQ@*IUQh7#}DI zz6TLCu}cXE;ZPSqZgRNxl=Oq8jAgs136>hK>S~XRK{iq8IeIdRSh49TkVdoMi z=DqfH_1tXR_LS|BS6!#tNR6eB`UY6a9!$Cp4(h2+S^PX1o1IuSoQBqn-#R9-+w%kIf#&lYa7X${!w_(jvN+BgW0K^tHrB$rD$Wi2 z60ekBH6gxNMD`*#RY{J@JHbTfpV%~V|Cel<;9%0^Qgunk-~OeOtwIQ)k!7q_Jc1H5 zDMM7hcX8I^mjRbXYD)OkwS}l>RnNQ=!q;K-kYtHg(coSrMKHe$O+D^|j!53joNYc}?8d<;zl;qGx4v}7cBLXUy{ly&CQ*eEW7UYUznGv6cc1n=P2~0yJavJ z+Q+)i3=bz=8bDoK!~60j2V8A9BIlU;mg-B)_c-$O+}vG+1|B~912Rd@a$giBKbnT& z*rN$zMS~Ui!pC)^zkblfN_bq=g)=tN6ue6*O-KXvue~u=De2M&x$p_`(O2OlUVB&8 zLY=q}CMjQ*(nA&<)5>A)Q&b8))GA^JEIoWs9HdJhFYMEl74ML?X*4cT{%gh|$N#l} z`S#80*Vh{z4&pBKX7Z;44ZPl+$unt!0Nw)FD&b4F$LL7nt>afR_e~o!bv|yU?Fk3O zUhg)%3H80rLH?nipKo=gG9FEXpnc!_4nE!ULM}Au23_XyY2=sk@`{m&jLE7%9-HPI z?w{fhzVT}5eNNgNTc=3Xqm30z<*<{gPNNnqpkl>`VjF`m5`=J&1ZCMXKe(JDuD+pc z{%fnir&D!`R*8Q$A=x>q26 zlzQ1thb@=8hN?j)B8|OwfcwgWWgV-pPUJseycKhJ!LCbBD|#_G=_LyR=H&%kN>m3q zZsW1kcdAft_V}N8-Tp|9@v8*<7-@erm2^qQ81wz5E2`aulOCdD^QYXf(>zpvg)PeD z@2>uO>m(MibL0L4y80WZPU?2^05Tni%Ao~U(jyiKF~Ix}XzN6j{5O>)z8jm|$u5^^ zErhF&EwY5Y)%D<=Dsm$qfde8wdZI+(So-}{XC;KhSY(=U!W57}8}V_6Gdz3CC?Gc;_QE#>wX+6P?E z6_m9~0hpl?_SB`OuG3lGr2a-t?pOXegVB#cjH;8oFUdw|Kxg$W-+r#@x6eTsQPhk0 zpLfD7!RtAkvSRj)%q$^Mf>*I!dF-5(7l>biMd0BWlby*GUfYfauzl%v+T2#ke9ygp zP}8$}KRPo!{<1O(rPQHgxLrO>8{CersY7>&!7jK4U#@GJ*t`D7{dbh*_SM~SNk#p% zzPl0C`AVRU6iqMCK3>%C!-pSB)md@A^n81+-0$~fMylHdH3Xu*MDR$ZXlhYqA9>Id zV0akDg9fgzfy+wt2@)}F9NXcrSlX7uS~dpy@b6^a-t_ZgxHW{3KD&od#cg(=@gnXY zumE21YBPJxZ*jI&{V=h<2Ex=^k|2%zD{>aD+;CzF;O|vKOdyfzoXP87{eYR6lqITn zZXCI}?`B(RFUnSCUp0HR?$>%FKC+FbJ<5yhl&kz>!AO)bCxsWa>_$^`PjdLIW|OjH zDSq=#%@i%bugrN|pR@{B@(a-8cl(=PiSRfRvEE){L_cv;>@>w$lGx$T>bNgXO-|_A zJGy#bi)L8A6Y!iqv|3J-f>8tUU2BKQ6cQA=F^;v=_*-F%%jMT}c48x78V|*0(YBH^fy;|o*z`8k<$WKuGJLbq7^XSh zpe0ksPR*MvL7j{3U`!^mrHCx}$EFO~*eq#gktyfsRSRBMq8T%RHRKXSCtHw4V#F5N z)YbxeVEk`xDrpUM2X!T-3G{PcGzNt+s8JPML<{^{3WZ*L2W&P`E@-e2*81flXMBg4 zU$%2y@=n=vyO@Bt%5Fbo$KjvQ9qwu_a4faHV6x z)@VwF_ustXYZm!Z_XhIwI%RV);22eMe5kAVaj3FZ2|RA`0!5nh^z(&s`zETJ(rzLn z%hMajyIJ{is51>b_ljc=m}Z+#1}|N-j;67IcAsMhQ=74!wA zfJe~(c)JW^OBOi4e|o!I8>cHuPKJ3Q{bJX`g6sY_mUGs>u$;C3!E$agv>mw_o8zn+ zeG#pwA1Q;HjXqk%)raEl2$tcY7?sFc#o%-?`5Ehjtw(vz4;i_y+w}=sQ!!WurY4-Z6!jjmvP4Z%whKLs5Wfa7;wPe|Y(V29mCN@lYHfX}+ zQ6W=q`KUC=^ZOx;f;4I_zcX4u@UO?AcbQhiQ0&(8GeU}z;Tq))qqE5`WwzWd?k?I2 zrrDF3rm4F9H}>8;sL8i&8w^sFCP;4qK|zX0M~DfiG?8YZLlmS75otjJL6F`=6cmIg zML>i|uM&C@={+crL`8an2!VvS_wU{B+nJr&nfG~iciwk*cK^ua9};r6>pZXHJW5Ii zh9N=45@ue}!l^v)#|5HOZQO;(#_ONzrgs-UK`qX6W+!tdJu|Ib_<3h1L(sp 
zss06B?KA*Cbf()-y}U}-#6B&h)Fmnh=f(Pj>n|;Qu$|?6-K~+cAQ8*1(4!-%U8~-ulPLCOPGkT*jA)_L{n44f8d?;_ zT{vK@F7SD%DK(Y!&bx*hyR33;#k^*+&HSB2E-KLV@-{N_*gd}xyR)|qN&~My zjxBg`P(oHEjK-mT(ReTk+zU_5K1Xy?J+aRH#ydB1aiy*?X8yZNAAjgG8|zKkfc;SE zXB>mmP+p|)@(Nu5PT1}^Qr0a)z}O84^4b!pGe3J*LOw~(u`IA!z~r2MMGTH=`*yEIqW8{RQTfUOF<^cQp&&4O(0 zYRT!tP_JupZc3@Zd&4yFM=di>AuWxeS^P#n-yb9z(uSS}&9s#ggcG6XqI!f2WT_fh zlomxey-gBYa%n#y1kQ^W|EcDDNBPwI%@a$~%A9Q6ttUR;nqB*YvLS*&47?N%#j-h) zoY61Fo2<$0yJcQh32_qaz*uTva&^ztT8qgsSanb6p0G3rq2_72$eccgCeT5|QP+DI z`eYK)ma*MQ$g;@n$#Qah^#M9^{=_quQ7OsUnBq<8lYR?0cN`%4>;?v`Fq{ zc~v26obsW5tv+|scbEf_UG`E&Hi~xvxuP-0N~UWXO_wnO9ES%HCO~(fRuV<%w<+Mh zcWRMg2+{PJ#YJV7#zvNN=XS7X?BZ^pxL_XXZvbXmm*S?NaAX^Q-ag8&uu;MQN@{Pk zJjKW2*~q=u2gW2sHkbufQaeD(C2a@ek}x%}{%r}m5zx=Z{IQavYg6;p*MyS_ubvN) z)Umm-Ei1ytDXaNR;)fkq^PQpiBW2*4sAYE(;hQNGQb&~bG+ppL@)RMb+dQkh`>2_B z9`Q$sJV4oum3zW1_K@@Z*piy(3H?oYXS-?+-JGMH==rKl1VO@c^>c*mF1$|`3(ZnY()`fFZv zD+&^M8f|}!S@j0Tcko^ZwUy|i0)Z{Vq6?H7DLsVK9y9lct1|pFp6U#SI*1Zot&~LX z=%p8JA1WM6e~HgpZOHgwB4D>E;7O1UeJ1#_Y*9A!wx_{7{#uNQ9M21i>ZS|1mUtO3 zoPpOhc7D>KMf5E3x_VdD<#)>ACIbYsZ#{X6Mhbqnz94`8wfA(RY~Ngpph ztWyWZ3>5(Oo3MBU%GAAnx*o#F!E}T*avQ+G(2OC5K1CJ4IK0f4 zp5T|F+YAZ$h)!1t#&Bl6xNr>T;8p&sDZWlCndCP$n=C>iBSEf_WHu=q1QoEO1P{S!SMiT1MMM*%{}7&NdO zJM*KiML2mwlNM27uUZrlmV7w>E&JRvBVMULo|jEo!`QxaNrHK!0Me{T9fHRY5G$kD zDRJ99O1h{dXs-wET5GP%#}QQ~8QDj&S*-;#~hsb2-+yVcYBH*joozJYX&vE7f`?|hFJJ6&le zeaZgK8RA?Q0Nv5KeRt~4YzT0kd;zS#PPHE%pKDn{ChfhZT43phQ>Syim#5ocCB2cQ zz7Lwe@wa(b_q~>I_qos1%*^bPupv*z_Ayv$s0k$AQ^=4uzG5oLfiX4#(ff{=tW0PXDlu%}f+4%6LLE#Q}lcZqupgL?Di7)>>#bJbZ5= zrl36X#o#vRnI*dxnFmYWCbD42=7|ij2_fHxenUkihbHG>8NxkODOVhs=arW!vS_^2 z`aAe1?lzEUEKl_>^%a6KLsW|4jzfq@1Yt3Iq9;Xo2?ZT8~0FK4fIsxwS?WUEBNRGR0|o_8M??q5$?bVEz->< zD~;ZGD|i2zz1h_c$8)bW2%VMI@VB*j&E6-5GxmPNI}FEzJJ5A(%`^^QDN*W^St%WA zW`x%Bee2awA0_NzaDYs*lEN+Bcxj!FHwWCCs<3sbBl<) zy-V$`Ujmj|)0#$*>#o(nWwoKJQ6(wnfZZe6;*&XnOYHVyEOEU%OwLyCiIs=`fn#Fx z&MVK`N-KSFw;BYtL}ecOnRLSIc>9u|{HsgH+JwoR@JZBN_mX$%9-kVuADwA7MHM}q zY_fVzcr7oUabm;LYRv^$Hsr~%9q5yl=;ge24Jx_}X=PVKQE<7Tusgp?OT>Dr57fL( zW;c|L=Z?N@D#te^MkDRHPiXlw?}byt3jG>xuC}D5CwK_<)_?Km8XL$24ax;DD=~Mn z98>zH@P_w`Rca7*X&Jp)=|}ki;jW|(b>X)Y*q(2`b=3M!c4h}FNq4Y+zpSGk&lURY zR`QW%{p)Xz-4ETpua{w~`y|7-2nlX_Vg<<{Kt%XYgtiYjc1hTclATFdA5Fxa2oaZI zX|c@oKLvP7t!I}&S$oj?}HUcL<7-dXe>A}qdkJB`;+>Npz0oSGe zsFRfGCqj?5FvV|jtZSPaZkevcGyOI#aburUrX|y#?70e1tAU7aQe!WI`x{cW>#)BY z9{=cEM4Bi1@#X;GSM zFrIouB2<|2?M1P6BVL(pYujo^Hd#Bai1`Ogf2x{w_gyYi4~dmMOg}ss*5K8pni_SM z>hJ;S)TJG`IPZASFfkm+^3YJz)BnoQ4VAD#1&$il14Br<-=B8iOajU)M9L&k56h|F zcJ+UdzZdX2-JTk*D;zBSeHv=rwD2RR19pCcGLaGE%1w#Nq6N?&QasDwE;;KuX-28{}<$r&VZc^S+@cj4SlLV<@`R)vwK1v)QA#}* z?qK}W`j(9CS4qE-N`^FYn~0_hC@t?Np&XWna4hqZsyD^_vy0-Uq^9q0H(CXql-_tv-^NlRgZZdJGXuO9b&)Dgid40p)It(}2lEOTDLFcsop>@_oc&&sIo-ih zeO`IE?<+hOT#M<6uu5sf)ah(;6O~j?BIiGhri65aJJ>8hklDj_)bb>K#Ry>UC>{(0Hv zCm+{XGBwW0Ued4{qgxShr}{%wVH1M+J+#iZ$>4pZht^9yP-`R(krV9LAX zp(?9A7f71%T*^>dC!_+s81-lWY3ogSIWEuWOFYTTf2@)*T$;M|7yn>|f9>)8swoy8 z4lHx(X?QjuQXG^M_@kB8P4J^7kDl1qX_OPeh0_qodn|RHxnH*_>}IyNp7N4VKQH}R z=k>-Zzhf7IV)4;>CzNu?rD=3wQY+guiVq9a-5O*3fdhOF>uZizA0L2OmPsX`(C2otZvg{{j)JpfIzySf zl(_rN?n=F#4|Y>kreH=ZDXclX$kXUVTZ0vIRxhd>7!h`o!_pa-kShd-T7NS`q_lg;Gt@Q^;)5yN{r`ch5u!NT6PnLk+ms9`iu{SqJit8ea$y-b!y zVY?e`zlUW;9L?m;yvlfJru8cu2h2F($W$tHsITQrUmNyEKJ?PH`K_K;eMW0K2?2CbQ40c*kD?I(*WGw|+q0ZLWUhqdz7kdo{YU^c3q+TfOM z9L4Smn8#f$tgVl9HzhOjpw(Eo;k!A^Cip3+p&x2tWL+Rew@uTLiG6fArZwbIja8{d*h>$n&Eh8wRmDokAIN^m(h z*_W$yz8N+jyrU)Vrg~n7WvB8lD4aICczN^s>!)hDm|N6xa`%;#n;1t*RK{JY6h%YS ztXN!|xsUSWWRr_wQcfRjWt@O}`t2K@o#d52a3QMQJf%dxii77ef!bgU7~b9l?6EO&1&^_12V{V)o- 
zb9-#_Hr|hBKHycK=H{~H>OG%%8LN`j@ttBI86rr*bfXv8I~Vn-e@K`oL2G(iRXGGw z>ATFbcO|FLeiq1NM_of1$rlTE_-gr)eSRCCK??81-j%-tAi6$mrR$Zu?%wfcRp;gG z3zLK^^qWt-36mo>MOu-ByDkYuaQb`o9ZEjIf3O+#%v!S@*ZOj!>eu%~&FB`roe#|{ z4;7fle|;FXa~$7QQIAA8<6{IxuVZFJ&eo&!lf319| zv4>OMIZ{*HD;|;KvdYD=oub5wIb?U2oXexeF?I?Wg-D?SsM?H@>c`vGoBB{ zHbFT$N3Vf?m0@+J7;5y>nP+m z>?JAqnW@9ec*`;H}|P%Va{E69q;njXvAhNLfk= z&fqZ*);Qs7x3}jnJ)$1tb)y>+=;E1Q?t|}rBK9sbMDWjxiQ_T_><(_3;zhFMiXD5> z#OU+mu0dt=yHt}RY662#(h0cFvRj+v=j5j=dn|241w zbafkJw=lYc^y`fhrW9niEXDpwxfEs@s@SM4%=fj+KIE9`tLK>ouS$PK$&;6oNG&?{>73Ac{`17C!A8uAh}vd|(_dz|?GX*8y47t(RR<=Isa!fkMhx^`Y~A zf<1^!jW5*;f_;07zFq`ec~+f2)19f->m_=~Nh6P0dD{5`>=#9Ya&96Y6~K#N17PdPtYO3yAxUhQ-z34A1t4_35T3)P3!$* z+HO{WyKUsF|1kOtA42`Z<~ReiL27iwTWlg^REeew5ITPUV}&c_^{N%(9^m)v=&cv! z1+IG;n6T~Tke6=KLOwF4rzb>JQk{~J22BS#1A2j3F>D%5poo0IxH4p3cdm9(>93X% ztd~dGT3&BS+ho7M>{)59g6sD?G2YS*V-I+>9ut<5(Of9_`{B1pt6n^`zTC-cxV{G0 z*jxY!`6glcMQuxgxvj8?1FA$RsK}K#%I)9#DtL{%{6duzLRtvJ!69cXNDr|pr^g*lA zs~b~`=}0N$raSktp3x$`AJD5j;~5-jjPu8pq1wg z%bm=FB?9ZZYJ^86fhKbIn&QQG$&2y@gc;qU8zUQ+-Jxxyx3isQx#GVitZEkZ|(#d$(n!cAOD+P z%*n7FscG0|0^_lrZziyP+5=`|_hGvyS^?Yl|M$h_fIg&!F$`}48xQoZjN_loDqxJG zh^F!~o+vQm{^u#iet{pO>UYEHc{d{|5f0R5&jQ&ohTxz_;V(y{wIt;?;WJkC*T=vw z<_q`@A9Tb6qYbsg5ICk?h3Z2YT88&837{TQc?lJ0sbAoeZyOkkV^W8gpj%2qMoN;y zZIe+O<*Z3JpP7H;Oe-{URfHJiE z9rE(b0iyBR@Fy(dYzm)dG=K8@OX>pe;44e)NUOc#qij?pMcA0?LmpWo!k<6_D@Ut* zX4Te(@b2>#Vf(~RX!4t;C>>>6I?>3@Eo8)5Dpb)wN*|>z}*4PbBZWo?G-6pRPJTtt5 zw|Rff@J6!arhPft;_PzyLu*mGD5{?8+7E+1(eB7hnhyO8F!Tmzz`%4L;;_<&o=#f5 z3v|U4Ga3W0Hv+Z%7Qc5xqd!AUH}lqOG7ry%b`6v?b!Zuv5xP{*zVW9NPtCDupUg3% zuqRY)*?w7CkUf{~6#o4PZ~FG6M}{~XbAX35SP6Nmhr$2x=p*1#RJs>T4J5*te@V%t zuUw;8l=P*X&wDqJ+FbX_{qnf1^6P4G13av4eih10=cNWxBKDe(inPEYE{KeX#qdJJ$+LSq)O@m2m!q?70UJd%hKOq&KZwxI z8IEC-a>`~E^?WV}xjV;op7kDM6wSssm9daTihl|LmvE{cPdl+<8oLn~{s{F`Am>Yi zOH)!KecoB(H{8@4eLPBfp0Nm(p^gTX+$a0*$7WrNr;SE$eG(Sh5oybme+{*t+foot zZdhuSeDuGK~ zg+o`i-iE>sI1t9}VX1b&?8~G>63yVp06y>TH~Acnve)_pb7hZ3?d>U?7JYWr)!F?b zM=2Oc+!Cw&bUe3VK zl+T}UhGvUXYUEKT7jVnngZk?`+ zP(hzVFTm2@$*PHh-uu?-vo4YHgCHkZQat7iZmWwJmZ1(x8-_WqyFX#bQqvBLu8GADA8ej(JPucHA%%>Pq&1w3`%3Z`M zs^zWrTAYINya zU1P!N<{5wNoBEx5JooBiz>8ZUtbsWP_T_zHAYQ|x)^;)Ke(yTD5#B)lv)uzQ#LaI- zU{86Ap@nsRUA=tKmKFACPROP@{D9Jx?eTjw?orrA1VsthQloB?p?ioYi%{~=00IKR zNB9K!+V7Z%02F9th$`BY30d;JV z+m|iXxJLA4oLSL9p0a^lEGEZ~%JX&^jF;WZrV_|Q{b)X@7@mG^oFN8y%6Na8$}JYl zevA4%0QgI1Y;*>mg51$2v&|?Uhn)i*Y5&i<_TTu<{9CxAorgTSd4&c?xz0qLcPLHV zg~dH|@NN;xd{k(zAQh6ec~g`9N2u46r_L8-aC6rID?ISLmpRplvbIkQ>q;3w=@2Bt zENJH|2u(7l@lHK%<1Uw(SrW@S@a5@nk*`MKIVn+MK;adp?Dv8t+L@`oOZ_eEJqrG~ z#ya-)46D}Uo$+h^K?u#IRz*X%Ti>5;93%O_O{lPUj0s?1QyNZzla#M?6iCE$Z|HeF zQ9)$Wfh@{5 zs*EhHyHx!ybiMpWEonQkMW@>;9&jP&r^}Gi0sud;<_!CWNvY3%IGn&b)a=AtyJ#jn z$NlQtfW@f?`UxCNSE4kMF1I2DR7gL+l1v`b3n25SS#TvOFdiD-K^nE67nS@je-DFR6>VKMeHG05_C0qV= zpe0q7%uYb}Veu<_sZ$L3RrtaTLt@6CwetCdZr$$y#(0a5><#9Y4^B5j4YJn+RhnK|{ z-k^k1qgD_iy?Wt1Q9|B(36^9=+MWF|v$C(9oF%SS-O-CQt2+_xoiBol^V!TBaLh5XbWDE7*batl0NesNhs;(P*7t+_Fq zBqBI1cls_ao|a=#}s!YvVbvkO z>Qf@TrDN=`T-j^YPkSaReS_KmLAnY16!TDJgCB1e>iwJ z(z8ZSNe!P_hETa}z$D-$4`>0FYS3vYb{Z$LaXCAsu_p>bxf-7N^L(o8xp?$((D4gB;%|!bjr9yYI?r z_0LnwrT?&`dozDWQ_u3!b1eaW$X!YEzaW#`9l$`YkpV&9$I`?gW`KC}+wF{)MaiR& zFpo>8Y^gHK+i}?iu(6vrnw!T%CM%uhu1J3BW=?r2_vHz8(L>zQob~*Le(ce!c2-m{ z*(8Pmou}WL^#2PYoThBOE&GNa@`G?A&?-|JY}wM-mzh7%FFg4^vIH=*gCux2lxT`{ zPQVxvaQ*Cq@j*0rleOR{OLidmSjey+i*Rk%Q~X276LHR;5Vf$NR(9r`qr512Y6ftq z8yb|Ho8#xDeA*@b4rlTT>z3MRdLJ8iEGIZ*&gK_Or}A4Oha-n+!gM`KP8VXns1w86 zF6{#xtRv%DBk||La57qTq`m1CESz~z+*9Y~PIcdda28&nqt`G`v;ZE(OYJ1*k)&SY 
zHK7p7+1eoX9LTBWnolt-@r3KDrz_YmcRI(jMtO98XC{H;4lx_$fn>J@Q|c>nM4!}o z3Ovc*;%7Bch)mD111_NIE4i7*%CqF;VOoR5+i0$iY`5S);AD(Qa0XTWJcbbI@kde? zK7}#@#}>7;$2Km12J8iLe@w{>Psn?j#o6~I?u5=smrw4*j1H?L3>fFJLJgw8?IsJUu`(8y z#fLbG<3Giwqn~%Kb}xjC*!V=^*yTOpG!MwA$UloB z{41_JpSssnTNfT?u6vO0p>Q>E{>F#TAd&v2s2=c+&k`4KO$2ap0SxY`CT27YgBq}8#XVl{ZoNma-|ysJjC2E zcidw*ca&(5nXMDTsw0;YUJQ?&ZdfLM;LL z_sUi!9zfl>P#gS1t%^?i`ALgZpSSi=K%3C$GTUhd5~?NZ^dNZhLo}a$8M9wIX{M-Q z@WCp(>SMdX@)qZ~0Jn>32e=NlQAA zvEEA}gHDJxy#HP`0dJx5krBiK9HrC|P)vI(DQLn@xy+@3jmG@FoXU19bx(dD)bnul za$5@C6n{COJV12r!%-gE3zqCE4CFKM@*^9(AC$ancz4-r}v~HB{qsGwq)o zp-)qO2h|%uWvOVw_Iw^;w(9`O! zE)K-rDRAlS0~&ED9wX7tN{<9o>Z;KUu|~?)W6B9}#O5Z`cG2V(v90b#_1A@AovqJL z?O$A!YXvpoZ!u4zm}$zZXpZ(vu5?SPO*n5nUAAf!=P`}aVUcq(N2WR|l^NX06ciD^ ztu{mwPsB<9bEzQoKs_aa@mZ+)FQ{-C>+m6U*Zc!#ACQN8E2?h z`o+E;{>%@Os`h_(>G>==>f8+Knc!tCh`H|MyxV)GX4)n6;Cj^L(v~;w@cJ|fDCAcE z)aRsj9v_$p%C4}gkr)wa>+TU4CfhQXaxS(qYV+HWe-ZiT&_Y;Tm%(69A_|*r8`0BWeXu}0r%didr6#+Qcei*!l4K?c> zchqcm%ePnclUpx8CFQPfiWb;Izdt+OThDn60N>^K_T;-^!i?!fDCJLfISIo8)S@P>-F&{X(~Hu6hJA(khRxnJZU6< zJFDgi`QOVqn>GgL+xq4OQYoY_mU6e{ACa zUWxVzoKx$Anv=f!qxp_awvXY>KFR2=Wm~I%y}jOz|2H~!jg{(uzY_-?31GFTgKv>I zfGl&IewVNWaz%Slvu;daRdx^dUw(ETJf(ck^x|c+j=;}Zt@WtY{}lD|v1g1e9S5F6 zEDy4XfZ3o>#GK&5XwOeygg?%^E?jzL#&^SI_T<$M?e2;-<;u6ySwT?`sa6I@nY-X8 zLKtQs@+84Ap$jKb-@0H%ykv4c^GV~%1E-3}%PI?#q2hjSd}+;!I(B2Y#`JiI28{;^ zH?TU1OD27s3id9XzULA4DaqFA{Eanx?eG0xj(?kl%#5aJ!4_isz{l%(qYKoiD(}{3 z5336Bw-(PtKAhd)mo@njSznpAUuYeu0IEP7#Bi}zU?kBV*rds*o&p{!)3UcoGTnr* zo`11a31a(eh%fV}!oHQ*#$}o{zet%xC`c42^TM&gh(uT&ZexBK$9LSxe7fpSS->+< zv9sBd4JE|$O?PTQHe^FZQx-3+6-?vk2Eaj^B*5bU)!O9p$P3|U$Q7NWO-)#%_nY~l zJwIAR+D6~$lGowxT$V)6V-Nh+t?i47cg~pu#ch}MmOroFama?;sX*3(f*Z=~Qq^h$=?$}F%1V{6g( zaS1^E#L~2olfY9S9So#@TA*#y2mSrq=HC3~vik^>)V5a>T3xc>SMh<5M_5rVH1}Wa z+T)s-d6RBD|AGUx)oNC=)ybm`>QU#Is)jE-p6b+SToWzv_M-FVr~uO>0}rm*?&3`V zGpV*K(q+a}+D;Ki6DAzttNVv;M+G+E4zz3$ABZ8l~ves~Qh>lTFWL__>c7m=w-ah}9evu;iVwfhL=tet> z@ts+Qp3M}yF6Z@xt6%b>0Ml>YO(Tj%ca+>w^#)WNff|Q;?vxyqD&CdOaV|yL!k1%X&0!XVH zoC!C%wx6WmmX@ex+L7tL`=S+`5TyZRBuG>D%|B7XqU*NfrH)sdKopF zNXuAZr28H#6%1_jSeM$fA}IW-Jk*Cy15~66LgjdlUtr9_uZ@WF5+%d#qqK9B;jp9qzY>Wku_ zYJ9j3(ILT4tm2J&Uk{%i$bQtG54x`3<>K&QRP>%o+Cok!2m(){-=;l68IiU6I~OwI z^w9D3k)Cs__=MzYb(#k6zK06=av zp{i0cS=>4ax4@n1K~KMWd7vTm@K{}yoO`3>wb^-BxuXJxEK2VvA9W6xw|G-DeDwAo zpH1G9eTZt(N~@G>)ZW)e-fqgi%!rOIjyC5lfNE3i_C6j}p|swiz?6he8gIf(CVb(W zD>8Y-pn~G@j~5ktqsX9bSioq4s|Lj+sa=Tr+3Lg+>I$$&a#~2D z8xl5kR0A)D7;>!#aA-VXieKzjdRn6HUHNr{vQJ00=u z#+p%DWjdkmZXW1`4PzPRJ1G*Eu*Rn}xvRNu)ySH9`WdVl^y4P~9|9i^_%#qkFTpmLTAx{8J@$2k6> za392g0fxLo)3lbP6gWxD$$8agSotqVv%tYTL`D9^`O))kf@VF}d(xw4D~2d&LdmC< z*xduVAwqIAgQVeY)37{C`CzW*-6?@eFa6nXt@qlRz1YTiG_}oxfJs;s#Zhunv6I)w zgX{R6zYd6I%P9-+pVSDEexURnJs+_*Ob+YMBu1**PTg108W#Hty73G@Gc7L_kmTvk zw^;r3!Vu#7FQ>yt?W|{S=}3qxxb)VE~ttWdlvD0wwAkEW8c0u(K__Gjv;h}2l?|F=+HHpJeqR2 zn+r>VN;>wF{*djbLgwE1X;(XSR~;)o0_pR96<+_zgRxeily0g3DPc23mA@D|lo#0< zk%Xn1qn6|QN$xs{_o5Y9&l(4_R*Sf?%L_1VVmyK9%mM&1L?u7EBH*P~SPqRo0bscC zXrXoq)IBnVV%Mz_wbUc#HgGrmjE6B}>+!0?yfAmcR3fD*Z~$xZ&f z-2YgM*Ta|p@JxZN&k98GE}afOc(wz7;skkx8+dU z;w`m`0op?y!5eJGkCH)u6qQDxqI4I4*pqby%|}DzDnRz3lWCl-P1lukB`J{#0#V@N)@)h&LFo8@%l5ToSPZx}Zb1 zWTQJB>aSx5WIw6=7jaep2Pf99BZ}kb6I2A*zn3@#6?F1z38DrNIYokYk!_?<-~q{FSJOBa>}rQ_LIO{Jp^1((Md zczDflUwPRtGTn3nzj7_z$LpYY=d)9}52QRaJm}3_+rxVxbhb_{TX#O zQGa^0;jWGAAkc5({~`O(EsCi+odp`T(|;--E;?11d$WAltSG7f$r`0%lT)QWXv?gArR{%A+3hwei&M;)hzl&wF4 zrO`b+ik6UJ8#v%AiT1-pQRx$p%`)?u?i-6d9_$SPTq_wWC|&9x0WjW*cRa!+YW6y0e@`QxKQXcNb^M#l~4zr1pTB0L0t=w| zH1B(Y^Es~T7zM)|`8QSQe`DnaMY0+pw?5GfP~v1LxpKK2%emW0L*9b}BNHEl5sO@6Rw$ 
zV2T>@{|ka-F%ASTp=u*&{!yQEBYaZAD6$qeD^3rxL1vglo^qAApO)B!t%nB2j>Sag z0iFSZQIb^7%M)}f;%3yW<%CviUWC3dBfQusv%wr1JV}O~m zFD%+|Dd9ZCrI`A8X5F}c`t-Q#CcnGz-jij`2NmIHC}6+z5VL7lLY*ZF$3{sytR5z$ zg_!PEMA{y5Yfpc-yQ}hJ1XUK_(XUds{F>*YCeyr)s}>237hcA8M><)LwS^JaRZdw? zY@Ytq6#u=Zrc?hQu!fGi=((&I5Kt;2SWd>0?7BwNT8u`T$+bP9i_&$qPmL}nJTl|_ z_&MqHiyK$JGT|Kzbah9#L66b|fA^=t{8Z;hJEYXmAS$juH!82H=kv3UpT_o*hsryJ znl=&=nX=`8@a<7!O1lga+~F#a*AKlgtD;$XOu)o%zT_BxP2_N@Mq9h_ShLEU7-+*A zI}A$%7HR-SZOl%|i=j?!sQN#B1N6K-s?EFYg#2 zw#QP7C>TOmr~la{Cu$K<$^Y2Q=9dTI3k7ER5A2Nue0QKrv7=uYgO2`LMZJ8 zMhGHx4iH1~P>*f9bafoP+3{OC^>u1ow)P1MP(GjE3&v_wGB(3$+DP`^8l-SY=UQS;8Y_^;?@$OgsbCMJTu*tZyRdTx=mAzAn7m+BXL}AU&BlMMb}x-8OKDhJ>n5O6Dz%1&29jEWzgb`nIUBWs?|tSMrto?}J+omi56w%d4&?t!G$aQd4NK(lCC z*o^s4Yp)%^PpjGA%v;{hknK5itt^{!^vl&^l$@!rd~XT)#{Z>H_LIBx-Rz|S7BJw$ zvG+%Cxt;f@3>ZrM1@+*qVo{U^P}|YGy2z;tIO7_4QBk^8WWsZtqwt<;3n1CFOtwVHY6*& z)3Q}}rHRux9YIo3k)n0uJZQnEw)y)ICkH{dPFzp@#p!mThsRQZc>y~7KNdTI-FSr& zccUATAih#4~?;pL@r4 z?Y+<1>z+OCIOB}<2V?k!iAm=CzPCQ_^XQ)M7L27NBaWk;(eqMUJSsB5y;UtY4C}u! zjJ!6EBuC(eua3x9z-;9S!VDr?r1ycLUW7FP|N znA+lDKrqBw)#FwBK91|PrG~6P#G4e$ zz$0W2L&;q|^MlWntUZC#&`0BBS9@#M0!O)x>R7d2_Kf=}u!|B}Pw7+8uO(EH$V%Uw zag_|M*i&THVX8DKx>LE{CsgSZKCx@Z`n-Rxo#l`7FH7F`J>smCSfCMO;RhXDlIjC6 zg7zd!VzgErk)a*NKD{w+q*uPsX8JTJqKBhn@#&3gVs83x1wECPf5mTA97{PG@f+p& z5e4p>`#YL$-7OqeIDdzb;(2EZZ~3Ezk2WMGPBoh?|5Hf}+G%_GunhyEa!W>O$mToN zVSB*)W0Vp+K9KDxxXqXM3JYV;TS+LQRT9XAv2RO)_MCZ0rWbY!iK)51j zVEasMq~(JzeC@?zhKSchVEhSQK7$?<`weQ=(mWTX@Mh={=GqAzxB`v}^z2Ti;tvTj z!0E7*7qYt8Jg{!sUw@aCS@RpLBA+27?Ii2UeT94t#&kJ9 z|6ce#Whqau6_cQ2`99weJ0FCQ$NC8)AluhDhN9F*y_&#X&C2D&C1iE+MwRj5qcG%h zLqtDwp;OGTTZ|iJ(uJ)lCMM8ti)%L-DCm0}D?2tb11u=ML^?f72K|&fGT)7S5_POW zJ;8HaA7_JJmrlLn(>ZR}{2oFqScFGGgsf1egzZRIFE%De!k}h9Sv*Zd*A}Pp=FKj9 zLco*Un>GxBKnBJy^<9Nhfwt}0?}xUR5W%`U_378eL_Lno-MFlv@~&BrZWYiP-8N7U zyk0DNPynE88lgmn1+KhsujnS_2{U45Rnew*D#uNsxeAGU52an`6EHukAN_j3Fzsoa zNH!#G5#kq;i!vuVWjW0q==AN(9hTXoM_jt_ui8pPFD(OzVm>~Ns5SBlA}8Qf&!SY> zqML;`jHV9~)a7M53$`C@943Ub#L3GS((v(}rI{MiLYWhiI;|u}u`b*35bja0=-wh8 z?Fzz($z-p)==5~u!S$6ZZG8p`qR&^Z&H4iI#z8XIZhREhxn!!M6(X0>Vs(*7&9p9> z(OH3=;^PxZx_VKuCzjBs+ffJo&Yknzgw-&6Vpdau*`$N=x9CXOk7*oJ)-JB- zE19(}u{pxeK>z`^=Og1lvI2k(u?JX&)HX;qi-WN)`~Id}Z_V)LzB!NhY&ZrI(rAMc zA9^!L#yO#*z?G{lh!FlpeufV0UzuKJ8x!|vqIMbnhdK@sBC6aOxBqk_-2zVS{`ZL!=qJDh(FIN!%H>;fg%UPwN z|2qq+B_Yk=-=MEIAeYdSc|bk+hcg^GF$lR#c}v#Bqho%lbnB^im`#L{uYYVsIC=@E zJ_qQ1yi-Of?_9S<9cJqlP(@+%d4gvcK}WduX`pPOVNA@x7f}1CAi}z(sgu-4WNHUQ z9NQp4rOqigH7XNw2Flu4N4|=BTS=|N)64@Ls_@d(EfwpJZ(I?cB|Oe45Tmz)2PPsO z_Cmwv+(+bB@e+NzmAeDZ+i-56y}~rn=wAf*^LDE%MB|`5;={u9u{=N{qRecaL)KIb zE;D6W?jv71V9qPuu?J;C+iizoDh}JR!btnpbhi&BKZM2o`fn`V%rg1ub5!b!uy<{; z5q}zc%Sj^P4YBT5I*3>j3j_@A*1&wy!RTHrr*3(Y_DG}TiIot>S3y^T#B`Hg&_&+r zy{jG^I+7C2DJ0Ghj)2T7J(bJJc>Icm<=VW!aQe98dE>;BiCy9zoUiV3Td@wyUsXDy zS(*bccsCI0j~AhBFm}{!BETEW#r;dWz5nJ`zH_eSy@l_1<(gPj!Oc|5-TIRjnK%Wu z2_d?LU~@+dGuU8aUsoL$X6ES~JOgXEG6}~{iB_YlI%^KUHe1=OxF@|00g5PJ$*YuD zfPF85?D88_HrR>&byraB65Bvpe$y@RNN1YO|3gQ?>b1FxHy;bKgsZD&z-Gw3T<9 zIE< zUPBvW1LJKs?=L?X?t+dG2hU1J8dZ+r)RA=I1P+t6NXf9vAoI372*-=E?c0-wOORATxIxtV{A!1hPkL@_)2#88we*)Y z2W<}A7XXr`8D99Dic~kqcoKdFY71u~b$1>92C+9P;W6xF2$8ko>@%lSVF4_+&4W9E zk5m|mUq9Z0C5t^MR9l|7o)gdqXY@peq}Q}e?tQu3s9JJV?gVpJ(Dn88k~8wF%Y!Pw znz%v^F4clS7wFp13m;U@z>cNoVJuO^{?9-9Vi?Z_)yD+0q2ioHyPJ6kt5LuuSuEil z=`Q?ljf82le$v;*+>KO&--{M8>XCjoLm$V~vrK1JA5Z37THKBxtGg~D_{lej*htSU zxp!(V#=jb$cUv!CxnvzNb7Wg>r#TQC@r9`-7=mdeDJ@3hVpG7vw)Ktt)XewIp@Pgr zO3aLo+}nn8@qN20$}}7&x(Tc4pd)a%FgpPofnE3@%v3SqS>Sx^uhOD0JENs|m#8fd zzZbR0IFoGb+W73gwg~@XvwSa6L7&n@*32QDx;3;PjbS1w;?}tN$7WMf9K}0FI^DfM 
zaVmWa)>lAB#@+$?aO6CSpU}D8)yozyKWB9seSZjkBklx}+U1(_X29}Q_jlin3{WKP z63L_g3F;{&3kWmcFWnxma&FRYr-~7l8XVK-qdg8ZZ;$G9M3lRFJw>6u{%9|jOM1oj zmdZoE0GMdo<2$5~C3J_7=rzL6%o3#Q-rS9m+s)}qTPn|bN=}UAy!vP%BKLEUZrwSW z5~e}3Q1BZxDFLqfQC(gWS>)B8uYN1{XOzWkOW18u$*1n`OFq&7^Lpmw%akmXDq+NU zOvBI%f7+uWC)?9w-6*YbZ zc6go$x&^-W8T89|O9n%Jah(brs7puN_bI8(0+Rs);RjDroR4u#0**zEQ!5ekK{D-@ zOd1`LHaNkeI=0lH%BR8}U)aK-m8Y3tv-$r9SpbXz-%baJuccOK0Zyicr2k=1!QA1+VPq0)JlB-P zPw`7tL)V)M+AF!qXP1LQ;JE?J<&=_Q4INhm(P+D+WYIL`X%oLRf6j|Aeo6Kt84sxr zwl~J!D*&ySN7O9AXLF3u-3F(zro6o@QcGkAv5ET5qjy)!cf~v-wLv^5=p0R=seDXS z?z0@Bc10RIw5B9fgEJ@Ki`mw!viJ18W;SllAM!Kr+YWg2E!tj18`MtTQNMz1gXP+i z?~?qN2G0N(qQ96t!_&LN5zUe$`}P}BK)Jy=$yL_%nTL2Et;}$%Nb(eY*p=`=6PA!! z2Xq-c*rm#qb*ZyifoL4N2V>C_{ACweynfN^W2>|%yTYxqbE#j~c)dkf8P7NI;14%( z3)|6I=Z71gd%42ftYnAIIXhcAUvD@T`ceJFWTwZ|QtV}!q_eC?FpVk*a8p4*rt9&U ze2-M*jK|PH+c1&6Hm)j>qfQX7bfoc$wlXPCW{A5h@gZbmg1~`i zGeI_r%$JRFCb84|s62Q2b^1ruXmiW1V!xL(`nluzFYPmZ%%9`mx=J+fk+hFhkl;l{ z>bbGbgFHO$46@2V)7*ICr`Fcis9gT7&BD@O`RdTFAqHE9h zvi($?hKHXarBF9Bj*mU34+t}DNow7U>aYv3Q^=C{8kWbZBtM|B&$ow&OZSPWDmhIG z(F>j*Ue~<5a;|E0OD6C5#?1Wfp*l~yBXbs}`&UKS+ej4$3WWPjOIwI;fHf*Ss_`^5 zD)Z@hk%s1H|EP%v2`|2!KXUuJl~k7;_I$F{PK<7$vc4zRk3IwL2(R;qg8)U$2B%yo4@1K0LznxnyE zK%V3aI|*Qbw!uUJkJQ4;_&{Zt%7^4RU@0ZyA{5`eD^I#Jj;PAt8gZTLlk0t6PTjxT zW6_qu$;!BJk!ug}5^5*Qf)wZmU_n0ND4k;SwRu&)ndq%kTXLBf5wdNOKi`Lcoh{F7 zhI@9sBsx};7HF2C5pa3DNV?W!SNNLZGq;{TwyPyQI8g6-Ag`jV4-qK_H zl3~=-NE-kSz};v~Ooi1qCYcc{xw#KZzY;S{Mqu$@Biai1>l3kl9@VgGYQ`C)!2|c* zj3tIBh~P%WT7WF6$;Q^kOLNgwMtXdBcABB)PJ+16l?uVi*xiNL4^f~Mv@Ja%;>BPG zblQx-4F@|F*}hMq3NhT4ad!I3Cbillo*cN4$fjW|c?_O%oIg*BJidW!=tv>XBJ&&F zKRI$#w&+N|(0hO0wC&W>h^XxrUZd%&3pSfLfK5Z2RR_l0+m?THvV!8cpq zMC|tMx=aqm(g~OdaCyl~vdxkokb=(h6JY??Ei&9R>Z+1%{ApaY+$XVXmjqrSS00|1 z{F->dNa(o(-4bqkb@8wb9*!&9?(&yiWXM}H3W5}1_+|Gj)kcMdVdLfl4>RSHHSGkss`Kl9bRK0J4)4sd8(2AXk zUj+OV!KceKK*9-HgAK35&^V++$~+dKCo?GL+szt)R79+{9$9=jo4xDkEsF#QeCg-l zSE=FVe>i&lF&KlXzy${=S~yIP7x{c5p*yk!eOwjHhui|XKTStoP2komE{u;)v*h*a zlJ_uQ^x52991Kfu&c_#&oN>;&LR2nus(!d$zrQ{G(_qX5o51{8TTmaNUj5l?T@Qpx zIA%f8wHp5hmEMCo54qYRIL$V~4ZI_D=g!vE1+m23zTdN?1;0nD`n^%0!`PSyL(q0s zb+hQ%;3|BK`Kg5mML>cc(d?Cs{zI{7(pN`Z`C+>s+tfE|pFuy`Z&B1$sU}0COyHY% zt@As{Hwn^yOUX6ob_mKbAp_SSFDn4P?K{v}_=*Wc^`D?dIUSXvR4I87 zK8JMUmk?kKR2$Ioa^~dQ*6>CklteGOT+$Yg4ApoYbK${PbJZ`3DFo`|4WVn~2lWK8 zX5Qly0DfWOcFh_y~m zTH%Sgy-TsP7w&bPL_PqOxKUn1pU!dOVVh-(r!uetq=xnv=m^)JaV0s9>6ip-q}J$= zoP%xDP4Jtj0+HBc3sz&X^yA(xXoXKZf(OY7%v20r_S9o&Nl(Fib}V(=InRDMo%g07 zP*E@y@Rc2!+{LPj(XAq8mUM>1a zROfZ4k%jCMU@9TfO3;dqcLvTslGw}iyS{61#_5Rd==sOm5D@L$VE!uEIKrHjGWN3r z(IK&A5IosgX%HE5iLoPCubn=P`L9%uUmLIj4mdug?Ktil+k#vcw5;rtK0H3t%lm~z|t(1*K+vW^4UR(gEz?0hNRrPM`0W=#A$J}ne^Vb|5 z)>fZ$4eCyR=?az;bx<8tD)S)BbY^ihQsVy_?oF96C@mZH8hH_Yv93G|->I_3Ui0QN zue4j)=Y){NTG5j6Lu{FS@Vl`pzj#{GL}xSI(EicWBcDHtwMKo2LMJTK2$Y;&EB9i~*c|&}D)#szQT_<4WwC=Y}v5;gk2}k}%)WNjPmv z6q(a?33`fro2VVDVT#}8MBXL3ZEh>Y>Rf*Ixt5!1`tp1StZ#<4K}Ea=U+Hxf>b9G2 znJIj@SX#RK^!RGljHv>nk5HDiPTTfO>?H4idnf1_uBMv|giVa-&0M6%k%b|GdB%i< zp)#{CO`*&=EG$Xd?ffQ$Jr5NA!sQ6(h6w#yfcd%C9%wV!zctRYAoGwyiK*>skc?wp z+?xBkFfUOk1^cn)<*frkg^8%IqX#8h((Xj&A^l834GXKsIN1`wcY}!ti(!=AudISa z7>|Q{BlnB%phZFR5O;9}iW&qS)XCbmg zLyLv?Q9^dZl&i<3IF`EUjEI4G^4lZHa1{FzHe~)J5l<3mrwWkv34Wca(^>i62Wr9X z1#j|r>tdHXURm0_Y(evr*U5?P2nHLjU_Y#~L;S`s3`i z=M}B)yH{Q(>MER^?RGyZy{*!YnKEvhXCUs^bb{HD1szsU;_VECcrnHYFSSyCOs_zj zc*-;GYO?8QvbRW5(g;u>c>llI?D$WcD*x;V{IetQ&yK)9I|Bdg2>i1n@XwCG|2Z9j zj}xbZf7FLK)!gXnG5=~`>qvitw_0e2kCpa>MP|o>jna{-4bXGtBA=vW$~K9I6Z>joTqId zyj)X!@zPAM`eRLTJm*W3{=O1!mQbIthp%TDu9hg(Qi7?%LXf34HTzhyO)iP`r*R_u 
zl)2NfXJ)x-Lbjnpx?$*>#2&Hbx2qc87&6IL4TMa9Qxm8Q&fsWLA^cgtK>`0-sJ3Q% zC6pt*^3w16<<5%3NY&!_1?7 zO|NOGwRlbx4FMdfgbw?aeSAL^-1en$PI&ywlJ2wjTtQEgtyFCI&x78XvegyA7ytxy zcQ8=#-^Euad-+~2BC@P?KZx>)EqYJw3(XGiF^skJk>WU>+j;TW{Tgt?zQWm1JUId5 z=*qaG-J)$A-y;H>;DPKUiAn<)d!YR@RM*)*`1F zv%V@v$8PCV{Oz14EM33`;FsGkWG|Kw{RsYqN4bv{CI@t%_+2r-OCrNf2quGAC zO|R-jFIImz`61VQveNx#^Ax>%=lMprd)KcCJa`A7tpgz^M-Wt%b!-h!m%rcQ0U!-O z-5pj`>90JcG;>sx-VMu|)XmU}Q-zq$^1O<1m5EJiy6OEvxK3bp#*^8>m|fx3hj?r(M9#$0O1B8PE(8(yVT#-5k(cQr&XGZ&ngfg7!(OIHJ*`WxRPl*BX|FM^q!D|9pu!#!i~UC+Hcf-ZU# zi}#5XU+sJJjPWYXHISb@FjHec7%Hql(r0M3)YURaoHFU=Fum!}#UwqCqdfX{2P^;e zWm4^`yJWytA0xvT{_x1+&E_j)nPIduTWB@ERbzTkC7~orL;cg;U+9C76Q|q=0qe-& zd!T%kH{n-j0y`WmNo;t4|4W^$OKKpTLE7L$u~AmTIIC~lvMIPnrL`3i9j@=|XiX9< zpPQ++Ev*PVtAkW6%1$w~)q`UzVL||sAnOt-)PK>(Q4!gwoi8)j6jLa-dR|Qb=W7Kais>5F_dcezr9flPVro z+ZTJqj4SJ^AYe=Yrf_g5gr|=1plSk853V%N&$2HYKXd!Cjls|L{Md|Jeh)leKCCq! zPhLSY1e_*GEg(X|gUM`X0Vqi#-n_RUqwe0JPpZdKe{JpSB~*MR?X#005J+n5|2Mks z|0U#Y_-`R^M;Lr}=XM-~iTGeB0F2_RC(b7yAXlhE;U?8(aHVLs=&vctF2yth&pfuS zf$!>F3kX-|;@$_U*Z<2%bV)K9 zL2U2L;-?e@2%%1uRhB1RCwix$eSI_%z9KkuP4W`x?~AdAN{G+$D_07}i zyV&Tph;FjMvhK2E@rRD#Mkc6~o#r0F?uP#3y4k#y4gHN~72?M#w&#hT!8PnpxLs zIIxk8G6;m1Kj)2Jey!&b#N=1x!n-|nbh+B#Sh53-SDX|uf2tOK>SKzmaR=G~cCI)k zD4=(FYhcr(Apm(?d-REJ?B5GkeOiAJTAlE5ZoMX_sOhx<>>Qn#n3HJ)T?uWVtiBa& z0UJHfL&6X~k&Q|I3YKfe+=q!XnjL!V_LYvJk9#H`t8{6*#3i=}d_nV~^oUCfcHu;o z&VzC97a#Lv@H$*kIzqMFha5K=j7-HJ-_4f1?7J2Rpx|TwpKr>w)4hoSo&ZLuQD*5h zqH{dOCMX-UR)t~^jVTe>kF>cNV0kU^&CfKre?o_GCo@e6+Y|MD-2TQ1O58gIH*e+ zB*a?~eF-O{mt;gA5JluKj%;mPr>uI4R7S6S;H;xf>>ULGL!VhtH@5fyE|LzxqXWX1 zYN=)%-3ZH_?aaBA(D(FpV%gm<(hdg7TF9v=5m?aGSi95=f3uWi>2-%aau3J+u>0-R_YZb5}CJH+=M7`Z0K_pI}xHS z!Z#)D#(igoUZ@xt`-@?}CLF5eTl11Ci&7(YN0W6Hpr@Q^k1Yu?(cJki=h78iewsew zN#FvD!)EQ)w4j;w38cGWYt_|9D3*yTn)?_#9^qU^%59c4(Qx5b8#{<650S6ukzhs>h0<01CJLq|06Njf=bT4v`>89wpw@wJ?je>+0rmky&>b|c;*j!U z9(}QZ?jW90zXxVUIRS*w=x~Utxh1FW_|HPJJKKSu&dN=MJA&`V$czFx+k$0=1D1c} zK(~J*Y6${H-5{VlVT>lo-4?UJF;5v-|FLFI+?|(M_lj6BVxQyW&dcw7xZ1OHvO(=1 z5{&z|XZ61*9`}za(*4^v_~*a>@4I*ScrpNOJ_fMJ*@n|O=w;2%=ovN&ZSgmT-y?tP zoB{RRf`jwFL5bModjO7Z%i}jFm&b4fg(8Iz1pq=ap+@NOMbw1}aYB;i+Ex)M#G)EB zH#1ovFU>vG)$M&Ml|l5y!on4e!A&hd0}g$V2Ed1>fdQnum-8X&klbTvz!*Q)uz-qpg(WrN3x|Lwj6a6JJB8G!Oz~p%+SKt*g3IoJAPX^B*wY%x?5;a7+ z#qqlQ)*NReE744Z+m_j?yKaVG@Y%oyldTrP6Y4dzS>eQ8JX;3>O!fvKTAR6q4;HCy zdOms8HTcudC+U0;PYQcV&Fuw*CNEnE#%_cJwvf8J5mCatlv+3PMFP*l*`$qW`-7tD z0=&@U85l73W8=ch+j~EvUuyg7CkqbR)fA~8U?GeF1%q{pWjOfIH2l*GOuPO|0O<$M z8sN^dRC@XTPVrN%<&KWqOmwk@s!yU?cmp7?D0G>9gofx+5RucI6hzpt7+n}vDv)I# zsHyQT@~F-2Z)#<}W?AUHz)G7r6LdcVk0s|kTnMKGdSNJ%jQ~B`7V%~~xaKr26hcd; zP4xG>M;)skL)|PCaNH`~ly(iF<8vJ1%k@7pJQaM+VyDwz>d*fnn=Mn8O5?P8uvtYy z1Hd`mPMkwVn@oTbrY7&;z3jA5KwpKg$r<;|29Koi=2BAZ5{%aoUIT6Y4Wc!1_IX7@ z8xDEtAf%1#xDVyq5uW-hF+zo3_3wy1yH+AWnppzo;!#YLCJhxL4DV9DX2m&%0!w5H zXN#mDQ`Rnvd&fi=&c0G!K5=zEhnNVEZ$>F-y93B>R+4}_)0|*4c$^P*FRZ#jvXh|U7f)o?#c!SG^0}3iY&3m)fo3CN+#yS$3?T4# z0n?K<77KZJ&};#mQ}eybB}$F2ZhC#PdUu_L_CeuYj7+mF1M?B@v1YuM3b2 zfI0mz;6XJzC$j4&$to$^cGm9a^jpqj@^=ev|wvksdE*|Y(g7^fQ% zM8pA{geBmas!kA$YP`I7f-vT+lE?f>Vtn&HjJky;nhK3UqsI1x_=7nTpnbl0Ii9ew z^u!6!z&#A!6d86?bW}W!ynH!pK#m3G))OYNJN77e;DfuyX8W1b@$@KvQsfeB>fzEQ zWZTj)6c24Xw)DB%+*l5|ztJG&%VMnW($zJaWBW|yqAlpmN%Uq3p?%4|19JiiU4nLC zV$_!N{XgmnH7D3gX}gJJF}Z$7U!00eU*_fv_fmgKM;|Da4AeRX`5oZp@muUf8~o`I zhhKx8=9Xwnevj3@%F2{+pHMM++G13gi;#IkgMBAda?7C<%ubeq}r z@quvD2E^&^h1@2~t9>|I9;jn1YKYr=yH(U%A&R&}l>Lq-E6VVC2rmndMJ z5bZkUT|KV8FG7{VCUT;$?nU$_DhETJ$ScgQ;dYWVA5uGjrsQz#C%Qfe?;1u3qDi&^ zbegA{FF8?p2|i&Ox*8W0k6;J!{nEW$O^rKl+y_4*bQtB$5!T)L0et|CZSn$q`V?Me 
zb8rdAGmpdrinH7=s?+kFZZ+XMuf+Go>F`#l{8|Ak2y%%%mZODEVyU{YEd`(?7lQuL zn{FQ{s)Jt00dqZh<)%%D4Fqrtl=$<3SD`P}d^&%FW*|!)Xg)-M0f2<=gDZom-NNPrKq1&WblTSCi?L<_4 zP`A}}-_a5)3v~=;;Zr$t=X+xXAWE~cFgV$(OXS(yIKHll3orJRG5EQv-Fz$r=n&+F zw%e(^Y)KUA4AKaMJkAL-qWegRB|n9TP{)8miiTaf<$OGAzHow=^2E>^HzRfdX_%P( zYx!(b+N_|!FB{juekiBWs?!1sxOk;l3B0cQh7jGOEGGPuL@jBfa(T5Y%kri4?B6C< z0i5(ps89eDb&`PVfO3*Oh?iqOTTSj`>9d$mZ|bX0s&$Nq#RZ9;gkiw&(AcMBvUU0kQ)~ zJD;J)00FSmt+IyZx@zW~IV)J0-Br(PpU(*8EO6+)nAprEhICC|P9kQsf_baHpzizz zokXmHr;|W#Zi@RlZ~54aT&p)*%4Bw3#TFjH33{?e=d8{F&bs0e02LQ0hAl9JejpBZ z!e`i8AKp11-@!Aqwz>TDLPOXu4)?7FMVby0gXJiI6qGr+w8REr_4t<-L&r;$T8B`Y zA2yW7AFNdmscW$}MNXM5;vAm~IEn-_vT*sK!vH$w)GZ}IG^L?E;4-o6T`nngu{iJ3 zLJ66Gb7o?H=4AX)OXIamNBNc1A-< z?AsT2-<*EW`br5ShrD2uC;mqUJd!k}yqt+Vdu$7K84FN$vJg=j2m==TFl;Uww`9*% zh4~@3zz;U=G_=Al+6Qo{Hl_qfC_;(+LJt$3u{l_4 zsjj2BoE`o;&F4gn{3or4Z!3&HQro~>8UoPGH6H@2d1MjFXHiP%x$8~6tZ9e&3wopB zzTVftCp`@sY5=VjI;9POpHuHQ=tvR4gr4Zi10u1J2MkKbiqvm_WDbyt;UdkpJC9e1 ztoeixgE7yeVv=bIzBJZKq(6Q??kGm?i>Tw1`#LQoYe3cuv&Ft(4v!O7B6Z@S4C~a3 zu_cw6ve#?5e_3Ii&q3`>f$fwU^O1KYIYNY6Cdndob7p~ult;iGyd(Lr6wm{}YP%*! zmiWYk{`1te8No%k;rZZG=f;D;>NG*2zWZ^_{s0HKkyIi#)x91^>bQ&T5CY4Du_p$qH<=U3n z!b~j%1h#_#k`JL_@pdG^4?mVw@SdziaIJUfeA`Hti<9t3{{h?-z!=55b zmXLT_)Kp$J>f?1WJnQK$w^jGetEMzZcUyRb^;2vDnZ%YSS}+6{j{Evzd?1dy z(H$SHvw3c~(X!Y{rMi`+qOEjlx>wrotZgPp6!cun|A^3ymkn>0L9%I+Vi&bCJ&T@; zD$=ftva}jbu`Yk~?SjS&1MKUV;)Qj)%WDTbrK4MY?Y0)mBGZ zE40@=rM(u&aH%~^RD?qzY7?Nm^1}nl7{_`W*h(JYh%#CM8n438>NriO_InRywsJ?W z1?<}x+DYy3oDRmS7jPXh5eKJrJE6jW9WfsaPj3z-pyA(8;A~RZ_zi^1itAlRi{-xY zp#nwap{_i6Vhc{d_i2mXYCD#zn1B*`|lMiqXE&&`6>m1z>O?`o^ww3hyiZL#%~ zdbJsI(YvRILLw8s={l}>&udB;6sZXsp9LwI*7BW6d9rbU~#1rHEY z*<&Jb0BNxeU^>^Kw18mMO*m0ztBTU!neyTNwOuou@60nZ!`M__i*5Q+3#r?=dfvG5 zpW(zYv@0S4%np%9N^G|m;rmaRN`x`lj=X>D(x;4bd0_2ma2A*#R=FJ8vUW{Kb~E8O zNHX$R_cv&QQbLW*il3pQ-d{QZ-nF!7f*My<>}-;Rx8KY>HXQpEKoQfNij^j z<|eL-bF_`j!s#u{Z{tK_?(nEv1h;216#iLZQaG3xszhQP;ZuQHHngX<*=H#q*e~o> zc;(smEqz~qv6^4Jy3`%qXj>RCgQIIS&H15nSTCI9>RnSqFzXf1EwctViNEHGN{R|Z z3~1R?=d;*}1R}H*8*9Z(SiwbQlMqEZK4=>o0o~_AIx>aXpGE9usBvFSXwR04*(lQd z5_hzBw028LIU`=lzUt!s=@y^rRsBHu`%eA7t8G}Y4iVZ??-RT74d6u;((VUFas!Ik zrnb+s-yI4O^A8D~Z~lCJ+baCeGBr@h1&fpINDA@LPU@MEv7cLK0wnIVkf*dqeE_bZ z5DmUU(voSs_R!LvqpkJMjm)X|9U&b&_Wp+G-EW?*r?#Mz2kz7U#Wxk3&6fOX4AAY_ z*({ioNQK)3`Ozr(F_Qsp4`kV5klyWN1MYMQ=)N1rSaCwrS2PN{jdoAoAvgq_~k z(H+hPHM0C#t&uD$Q715vQKSlLyASx61V3m4U0H(u!3mU4`~ez1;xzQn<;43y`M-JT zZO@ruThMP1WCZgi z^x1c>72VSYg(~{4FLGW?on*}t6}k}$4|^CD5SugjDCr=RAoirEI`cC+?;{E8#Q>9g zzj{bU^*FyH)^sBPu)47xXNQ{G`x^Nh74a8Eye2zed#tiP|A^HnxqRLPrMCG@q*@(& zs&3|p{{S5Xlup_`cnv_}=NPDUvuMSV=cMbt;+Ph5$!h}7n(q6vl&2rNU^R;# z$`^!vnh>+`$F@`LS4=0EtR0q-D6g}WoE7`xl*4*d!lomBQA~LxR)QciEcH>Z`PrPp z@p}4p$Cj;Zl5(<71K+|-7SQJ^_WshBZuD|Bc)S$T$2>}8BctOF8FrQSs^mCyDpBHs zf&x*gkNW*ApW}CA)r2l>IEp!M~jky@!T zUo-I6=egYA3tfhT_j?|F>$q!Nhip76veY}4W@II2ejJ_QpqSA=9#wp#ykH%68|rVy zolc(Uc^Y!JzU{4GRqZPEH^^9g5p%@6f!S}9rO;29FJejisTANyN;w$1cKT>5S!g`6 zo49RG9c}{ZlX`=ceFQKn~`weI-g#i)?18Q99C+rk(3>QHW$ISq;S9JXMclm zM!!L1?>fwxSDvS@#2+2iP`&2IwB{tV?o8|Pfv1hJdrWlzElv?|i?{w}={)r4dcvQd zYv|AAc+**$zFB0PuJg<4)VmEh3AqqN+D+x3K%WP->{jNQRrvo?5}MehmM5?8W`)0%W?QVnNX-`Df#Q zef%`%Vhby+)tkANu3~X!v!1^})no5I11CofedtsS7z`02b@!nMirjqTQEIuz22h;) zJKxOgZI|{z@-~Im4ZU4}rmC9`lTz+fUGS_Pb|bG1W%}04gl;6Q8tYqxZj7OtC{VS>reuLnc`xWWGmUJ3+77Y5UioUe$ zQT9Cz2T3JVt$uB4Z50S|`pG(;@u%ppTlLp>VP?3i)<9Sw^BK@9{KX`?I?>AJVw>iG#X_mus=8}9awPonl9c6q(`;SC@`}2@Fi*~|^#P{e4 zGQ4NP^4j0s%>RWecvl}bf|7D(iPSMH&@A}O504Zd+k7(W>X}h?c50?}CZts~VO1Dk zZ+slPr}UL1K-Ds5%fbAT`$T7aeD{zVqyjzsm%se*$ExVX1v=3`Uio{!&zo}6^YPJq 
zTb*R~g;x9AF3Zy^uWETGw@0i8+(&IRtEA2YOy~3HiK-fExF*PXNb4*}|3COK`XdZ5 zSA4)Ao=Ylq+JSf5=2|A@Stu3OU7`8@S?U$kzek{Du-Nap7xqt2zuZUzRGA>sK)^ZI zO^pm|;`ncQ#u@L)Hs<99y}7gq-} z{7>sgfJb3>%mM&kJ8>I0gPs0BYv1#rWq_AID$z;igV`B_^k@z94tOOxm@Hdp;b2yr zksn>)n$f0BCp!?-#fJIO)^xtTJEcDPrCMxrn(qpX|87R}0B=E{wy=){KY5b`l+(mj zOaUDdG(|G5d76MBM@_3xe1Z7qM^#Dq7g7kjIhR%P?IP~2ybSFP4c*pS>C5n!9R{%U z_j(%?pqTVpPaYS>y*q z9ot0hFiqXOkY)ci2s)lr!)rG7h0+4hR_N1d3h!ZYu`v^2fGS?tBRp`!e)a>vow{9!TwU;bxA z#;Z5Fr}2kpK~YXZuX!BiU>Z|2-zwCY;sQ%MAAEI*?>A`U;oiUcXG|4k^^)YIS-iU0 zY=0e5zbY*oa0aur1E;a>n4>LyFmLu%HN(|a1dcjm6i_yh3-Q<9;lfDHgRhed;*j|t zhc)k-`m6BPOm2-BHY9Ls4HuWTu^1>|%SM$du3yJh)=O)%{{BxqJBs;bY;n^50mNzgbSaPm}n=CjPMi3?zf5 zSMi`NPo{FRSIyPMqCxnvtD&sMqj96sdA*!@-k(CE76Q+oQ-CPw8xZ9n*$%t+i_snM z9nc7LA+|pxX6!9EX#~e~0-JY|g-&Dw%uNRSmnfNtL6zW(D81 z{syIVRighq)!hUbIanOn7iqwl`lXE@|DbjH_tq;lwWC}s|7LDA@RaVyZ1gwJ%35I5 z{@vRGF;c<^@VH7Gv2Xsnyu}P9z+Y{^c_UV0yEjNzdAvrFZsc%Gn!j_bgMC&zXU;YG zCQDWGvvyIc`9lq3K)mq#2Y^w;tkSdO=)@FYO?ZJ@zzcEcIZLnof>8lTQNaY z_W-`>e6ua70B~%1i#UoWg2M8E0nn<%&&MwGs~)TCM z@@dEZVsOG9erKaXEhe^rgK^%Ng6#&fn>>cz|@`b1I?dJwQ5d)R=BcM* z;GBQ`l2{zQ1#9Ec*mnpNz-@5;2DzE}P-DUxVF$hm5Akv7d8!Fd_4m8bxjs)wNt`UM z%uL0m7jXBtK78p#3eXr|>clT&G*uLw^4Q^(&vV+s;mgSvv>wE2H2}q9m~HVaeEO!}$t16p zfuIQ)H&))K_J?C)zxw+tWi-+a$4589{>lTYdI3Aqof7&RWCg%B?r!d*k4_Hmf%^=P z;n6+)*2B*lWQ;zu` z?7ewBlx_PzK5ms3OhmFyMO2a^OSTC~QdDAOo20TOvSbODB5M*tsVqr$BWuqG9ey8WTyBqH3xu50x{k>kl-{~ZJriw#lXhCi*F z{|&#IuPkxk>WL2t(IXngqDO*ucG@(48DXRuvWlv&bBoggNuZUl{J)}= z9m@`ciJjvY%LVJ;Xl`P&h<6)WZU?sp(ZAq3F#9I0(D75e8zyW=JrSYIB4<~{D_ZT8?O(QaFFg*X2@v851iPWqwE|vxJB2gEC7VhIV4Grnj3!#tE z^~p>WzqR9YjZJ}ci2mP^4%X`{mD!e!x5LL|ZuET4de?f#^VV2wpW{I828XeB7DaKl zh+SBPam`8L{zd%!*U|NslWj9Fp)nZ+DDQ7*k#E$1hSU8$dA!ELlWNlUrltD{XNO#C zK4eDF20d8p%eH_8(yJH~p*^;Owg0Ea_}?(siw5DwJ&opKZ#%M2h1R-Pa5D=X-6Abd z{O0h=X-}igxmrJ(>Qd~&l+7wEzVQ?f5hUX{HvGG*PE+knq5^lO>G0HL|eP z0A<0r+V&@s-dfr9Z+pZtrY_dlu&ell`&>*AN%=T!Wpyxa>$UuXJXQW&jqU7OY$}gq zk{An9lHYjki|OngRxzp!OMUuL4jh}PBH4YV4M6sL-6crM;HVVO4m6ji=)~?&5nl@0 zGO;4*v#`1=Pe<>;VU|7eQOs}ME27EMd#)D*n?Ber+hnSH3I@M+_}oEQvUz?%;l-%q zM^fJ&V~zJRZF;=KDa8T1B4KD)FXW)ft`K@7LE_<+m=7XJc%8p9+jrL6*vg-ouM9h( zZom8Dx!6lFzVE;RQ)|pqQe)*3I^WMnK|5W@DmdjG3V`@^>mJCKke{&=|qWHUa~Uiltl+Az48nEkOioawuYC6}Bf19+&*TU;Cdo zA(>El+2JO4>$px*Yy}tkcB{Da^5sdFCD%L61@-&Wm?c7M`4vUu;(iMRLiW%8WlQxf zZ`NLg0U)XAk#)xz;9BiX0aP+1d1o9e?N7h5E;Cm69)bi|gECPy*5NpRNzL+qAuHkV>Ozz@jA zNu_BV?b&p{s+7{4e8q%~1IA{5FPTw<8)F&wTgl8v1Yi$0uCR1NqtC6IMZZ$Ogwr8+ zUW)8L=ele!H60o{83*AqYG^*9Q)H6&;(LI`|9AlWj{*dFCq`0k73PpZ?11~ENORo)6O|l3zaFoCA;Tz4&*pvV;MfcYBb(D53 zz6=#6J7&-k^E3UkfIW_3Nyt_XwQXG(Sbo%Gr1^ISXlx7WT@sHz#&YgR$&Rl zX=Iktj@SW~-Q&}}XA{DHY`d5p!TM0`La@&EuXSmG+xQL)%1bJ?9jH z!!aSws%%U1Te~SU8|L$aURU{5*kFcx4)hX;6)|ox1fr4>wgS~Qn&ur&Q%#^1hfuc{f9?k? 
z!Kb6(QiH27mDDO~>}C6&;Zf&^krw0b?r(LA0XwF`6wJF7-B?_+p_<3og%;${f^1E} zmy+y5^GIcPP3O&J-ibm0G0f)#KQ~Z4xg@~DYKW{8fRLJ zZI>T$Try_b39xa7TC2V}ZV;|{L8Dh(L}<$F0Q>m@07Fim0*Y_kKlp!#i=?;C*J;~5 zeQ9Z6_QzIXJ*h=Z6z`gq2E9D=GLtrKO1B9)B!KL-vL&T80SDUI6C0nW#(`3<S86sDm7W70Xr4r&k+DHbI2bhLr~H)|^>_lAD@ zNEr8<+~QEumHeX6-P^twZ+&~+rN&@kJM3V3FobSa5PgGo39pQ5wk$GvxvSB^LEb0u z;M1Ego)ci4_o^3hm$M7nZr|8tXnZEDa-zBU<4w~^+@$2Ep37p`i|@}Z(e zDzJ}Z3Qnyq_w++PA~;U*pAUd`NetcD-guiM6t^km(~S2R+Q0w((I=+<$wdy>qDfk2 z!!fqXp$~ErD<*a5yvZl6i&|~b;=8TokWrydw{oHy_kP-IWY*Wvh@U~5k)me0sJvuU zB|1CZKbCEV^U02}aE593Z`yNP%6fig9}4=H3nJczSfyNn z{Wg2wZLxYj1|BE!_8+oLW@=zsT%L5OlN^Tc@-Zm8Kl%}S) zPdn-B%V5@~N&`7kr zMp6*%6puO8y$h22imdYV1(V|(lG1xgjkD?Q-)dnUqJtYM-AdDs*Xkl}ol@X?zokfS z*a^MN@*U*$kyPf6J$@FW_pV)k)w3cYz6mIf{b=%XEYOdcWrgtde;3GtQU7hzvZ^f6 z)Gx@;O5f407mP<6-wIC%n2dvJ>1<7eTky%N|7A;G(7 z*a6y^qVXYjU9hS=B2=h~5)5-{`xzzv4FMjvUQ;MAdX3WFffDJrUFN*46b6k}XJ&hC z|NV$(2{64+sUy_;hR~=(N9he9Ec;d($&2~Gi5$sr`$GvHd{_L|qUEeg>O&1h82<%# z(1aDPPs+ikX$J@#nB9_bEoPsxN8(jWoV~u9nb|QG>(c9;<}coe3YB>hS*vt^OXgr8 zs&NjZne(RHTq&ykvAc~S3#;B;C7MwFGyUkruLB%0o%m*$vM~3Xa0^|D)03&GpVS)t zA;NQXUA&8VxqU;Ddb)A8mE5bX83!VsT`WmAo9BG0DVeI0{`sDtIV?L6MvN965$Cqt zMLFcbSn0_*yCv;JL1#zQJMDMlT=V0Xtk}l6E_F3z8JA!tG;F*65lQg?g4WcKc5 zMVa92Oc3Q%*@4owT!_VDUe!Fw;G0o>(#rwRfx^1(s92}`!#5h|+1}J9Gi!a>lRX7n zbQ7&qU~*2ifY?hk2;U~S5cv*4tT zfMOXqUl5ZH);jy0~V~NMn7f@fYoX}7% zMIM;-V1N+>!hRqvvQVE~6fwZcH{HPai%w%ZB{-E>ijJIovM7zN%%Ytsqr9hE~DQnnYJ@ zMdY^G9p#fXFssPZu`73;wG#84P(o+#R3KDQ_= zU7?1Ya1dyuDddWim;A`6o9S%S*RqYWh%YGaTc@&{>ZWYr#0&)6Q`_*);;!qm}^=zV8aZF$WWNp_OR3{%$uDMJ0rPF zy9qX+iY9G@$k~P4oY<{u?VkO$>!PsaL$?QG#+?#HemYM-vJs{7SuzdzzQ01td8!`( zu1PIE9XymM1b9urdw974`E`o1_`9*cIliQ0VBng?C*3b+9{8d@!Z=Fn<+Eofon6(? z96Sg30JTVvM4a%ORlrCGC28)fS8OjtUbg=5?sa?Jv)hvZDmwl*HE1sYCh#Za$Jk@t zpY$q)3p*!giCvhIndije)&rxz6c0UYX}p>KuTwm7^Z@08>N?Hm`G<-*xh@H}cB?q3 z7OF%nbRej|718-w^g7L{eO>l!BKfD6#eiAj9kq@i$WQ|F5gJIM8R`l+sfwoapQV?n6x8#J}WoFjN7?$Ir9v^FYi}dmr2&A3@%1 z90y-*|A!dc&$0b~^7Y?4>HcpX*>6u$%&7_Qdd61rJyN;%cgqEjqmaJVUpyJDE{&s^ z6!!XToz9(-?$avqXL~Mwd#9nX%C}LoXc-m7G@R&{1;{6crG3G=!2+_dj4}vkDv`UD z;*oTqKGg?)-U0KV%QuWAI;>fQzUps{z98L` zSZLF8oVWY=s}aB8wm}=`7|GeM;r*qoV%LA1blPTn^GRp<*M`$kTsX z`18(Zh2FcDbTob(9SqN@U)JKI6qkji+rTTx>V9I=Hjai+PL;Eh!~KK3jYR{~deP?f zVP{6J+sSf&gk95XfZdHH1dbz!2%uuiq9~$%1a-d|k)m1U3)jcRqAQc{Wi1DO+Iqxg z+%0;sXoCTZV?mS7hy%4O58@IJaT18nwpOjes7$uVnM6>?Ix!l@s!C%6m;{-%7wXfb zaX{tAfs?P?)wnVSl^2l@GCl=0YS%#+?p+YO1UMXz+#Oc5#&X6wzxsxoG`N^_+B3gp z{~-7v#9*ak1ElihE4F(A@tX^v82dR**k~Oc??@eqmy~s6*;yUKg{+9*4P@5|lQ6hP6+F zW8?-(6yoi#biDXoPA9BwJak&BdAu#7UQyx6nk2AN900>`xm01X{u8jGr-Fr0xJ@>&Ix|(A}G-kY~jI|Y270mUt3ir#)o@ap12Yi0(f6r7Uz*B~XTQNHoS67OCx(Vd`LgpVKXfoptO+v$Rs||c{xc|#>aBAx+ z%<;bN%(}(ZTYNqqZT!NX!}a*WAW6;Qx2?59JipZlZcyIq24?I5(7koaJ7nE8QD`-o ztMDUy@eiESFi%f^*AXLM6HhG)>gtB8t_RdkZFoE&*6XZ+6W|Ck+)XqxG*-Vw?wNJH z`*8c<57{5N>&A9@96QGgeB^wH%$=ERRZ_j zrLN9W<3DGe*AMF-%>;U3cs^gq%+IC| z{F}Cv{sbD)c#0ntqF7CV>|iK4|(cllrDyQ+}xZ#$MQzAH~_ED zelVc7$~Y)Bbm!B4fsj-F?h;u9czLOaW(hccankqYB2lH;K>c0k1K#fZt!VjM?RDH{ zy#&*XK({!1sJha7^Fk7pmjE{ z!bZsO6>{vLE`;YMEKwP+fubwxHOEa7KSSQ~3A9zkQ4aowTkHU41la`jwXMQPtOCq%ckI-EaR4RN!&U6BQ4s0gUecFR8s1JBs+hNPNfhX~%;?>%ItyK=Yz>(pF z5pAiI!4#(7$^yzd935`jZ6+ed`)CH3_$ACMC6Q6GF18xsx0q$1#wU5UtNF@i+gDEf z6QkBMEH3U>jo6Y-Z3Sq^a5NcV^NOh?;4k=Vprpx4n3!m zs>%BNvjDZqL4d2|kta8)vs#1ghC`CvA=z7%;0P=xJK3>TDL}c&6@HoY-Njrh>}#*= zRSfI^$8YM93-iJ+(+s~`@|3+6!@d+w;M3rHQS;0!;)ckck-cUG1{6uT$Hz$I$nPm0 zb3xRrRO|Bjw8wR$DOQ6AJ(^xYr|+XC>l|#YXAcY&MeNhg`YI#8nY;Vd4cvDD#aYPj zggiqU8m+$Q!t~Y5RM34g^;lY%tJExdvd7lidG@HNVJ*AMP~IyW&_WV6!Zcj#exDQ!_$GBH=M4b?lrf0$>G}Veu#b>YCkO23VB8%6$fu)CloHJ_+Ts# 
zCfQ6mJA6A~yT`@Wgs^Me&4Y4lTYq6!sgKQElupELw%X~#iTN3ckohQf?CGq{ZpZ~Y zXvFm?c$a!HX`i)|(@lmLJO?|1#Z-BKjt3_C-y=D$FUQ_G7!0&5`k4 zp|0C_u_JFl)89*S`OjTtKqS!zf)Adk%S|VaL$ZcjPcf@^Bk)(CtP|SHvwk!Z0&gUx zB<$9vw3W8uW=`fyD(nb(;pv=apu>PYH(Em7!pO_Op-KE)eo)fr>HUf!lHzI>Jh z^E#c_BW`f3A<6-K&$flBD~dkMo9qS;D#G*J?$Y)X6}Um`4D^5^l715+=y#SCh9amW z0DOYmh@~m09N13!J>n#A_d&`8kEZoKACr%#=H}e1Xv<5PN%U8{>lWhY$k{RV%#r{* z?|vK6OGfFg+sc`4NRCpiN%SY}rUu+X@PTxg;kRwQa_7ou3l9eP(He5B&o@tUb4mBj zHG6pGWnf6Qr!;q>w7vKT-E;N*BeVQP?0WkZ+*p(6w?bGyX zL4DoEPg^$z;PJ)6p(bhg$|4qT%*19*{{;WRe83Fo?`V{tk~St{kZXFMfw6w%3wkEj ztq-^f<`sQJ7H=9ALqdtCkQ|hat(w8(Fb)shM)WDQHwKVA&F?V|dOi;S2}` zVHE~At81D)738()_366%U1XBEpKggDyco+^96cuXZlnQ?&T#6ZmnUn#QZQ(Rn@I;- z`#%-D3=#A2jKKY0Qqg(z0}YwVK+vE#U;|QUQYiZM1%54^M0-iN>!yz}RTOz~HL;u8 zE?Z#-4dhI65`ra~ct??KHK3LbwUW^L-dzWefn)rfwiwI7PKz~0V(|t%XkT{)$9fgV zQ3=x!+EvVwdLHaH2yZM_Ve?cbDph+nfquP!*wd}e8shMO-<}OQ@|&IIp_N@A!njRq zW>71!24hHibI4gcs>IvGT8i;b`lwRRio93i9(xWpO=k!HgXP|gp_}9r5Z!~PV%J&E zbwl7_Shv0Gl@4{B@Ky>48f6t9Xq-Vct>84)%}u6el>l$kZV#?Iv9JcpDcYp{F^^=% z0{Wfy|GHqWf6|%`M~`6IYbI4796A&$_fvkU3~q~R{?i(&Xzx_A~y41Zd_5n zP~5t7d*zCR+6O1zwoHf__nfk&EupQ9)kS$LTaRU@P1$SLor1?qzU|E~di1`j=ENPJ z`>0IRi*bMtD)6R_2k=}-P5WWy*}kLLmV@D| z3;QBSdJM=^Htv2~TXZwMgGb|47Tbwvn_^A-uA=KQg~v}e!g*7wJ7=VjJY!^Zkitp4 zYg6ce1%{_{=-Axj`~bc;Y}Ao3;V`utLgh%31{S}W@RR^B5!S(0v!IFsY5Q*)mfRNY z86$ccDEpnub`6CUcZ4FpPFqGddbtZ?NOy4M2Kh0@p6A!oP5mSrM zP&3T(4X|8XVm1=h*nDk7EJu8v_@i zu~9WO<+>wVPDUHbH|B2L!Tq*JGNW@SM5^2OQU22;3-aMY_%@ea6-wJQv6DMoH&?$I z&`e8v@uRr7&}om3R8>Wt=p(+pxAv)ThYf)I$siKX*@hfNdi)K3MZx<+;Jw;T0}rFC z+6<)vE>4-wb3WXuKijq-HRUwo2?J^5*xk^#JAHEo1z{+&P<#F6U) z(&fGZlMn~nS?7;>m=vS)^R?eXYx@vim8Lvw(GnD(xx-osGqfTZn(skI(ZLySCwXYI zk%_X0Qwrkb4VQ0sw)^jpe0A@l#07VLbSjz)JEYQqohY9NqXnmA;RY2TWpys@qsg>5 zf&N>f-Ha%z+f^A#mEWqzq56r!uele)aU4l@t z6jZyE=~)B^`LSQkAV^;%JzK;Z5XpmYN*}3#6s6#Vwtvjga6EBwXSb0_Z&3Ta&!1Cr z_g!jA`RZU&h#wW4m3bhHa}0Ky)0Fqml-@Nra%)K87*n#;JbLO7NXr+&sSY1#mS`1s z0PlwFvIJEg7=Zpb=B==5ztD}eM6rlWCpv+TR$(tFjg{X81Y(r^YLpG^BVvk>K}mJD zEeu`mEv+EWfR-u^eZaF`sR&8pjK{d#TM$1b*U;te9R@&Jz1;J=&Nsgz8tR7(>n&J6 zD~XZ!*{s`S=SDb-d7FtVDQWa8SEEfWOO&n0Vl(ff2-MNGTTEArWor)j#VumAerH$x zkP2DeroDiu)y2^`&S%BW>{_OovpVjmsOvb2Jx=57uE$f)u!uK?7?m);PeDkK;a2sT za{y5EqZNwJaqdCAEoE`wOy_zr?8(zCCECblm93f4Gc$*;+w9uEg1Wy!P^t+K#c#ou zJOZDS8si+j%bVm)uDOWYV91PTx;*^jCyo*M6UBkA zxr91LbvFFc%({%XOq1{Ro!PRxpG;?_OhWZNfx!NqP$EU!Vge_V zLfzmE@|0m_^^E9V2Xv*l{NkR>9L5t84)^gcY(NSBocKtU{Fh(ZkQ?EwBgs@aH@>Ys zQ{aLf+hxZM0gTu%Q&N+jK)jOF{S*b{E; z0lNSyQcPW1{~U2G*dZ=}2^*OKXhaRY;SL)y-9P7;Hpo;UZ0^F23g?E*iO*8SpYrt! znM)|INtrq8fMSdFO&B)JGT^-`(TeBwKC%K75enOmiOHde~dy$j&T^KLl9t3YSgdd|~I zGqA*yql_j$xzu#8tN<7Nm6cDWuZV#`s~!KNQ}lnL7;InjxQe$?0&?xssCr_1go@6e0142_w`^rr<}TjGloX zfe5Bm#1h&3Xe$Xcb_^@jbOy09?IklYfbPrz1xRqJYYz$s5h7e1Xb+D8%*AQ8W=bYy ztSGYhjbcQ(R*Lz(Z91>mWid!;RCznP2MV1eOyryTL46@VLPD*-E1iUDhpks(*XrR^ z_Pjx-2`o=$#x%v)MlAi|6&{?J9C+9N`?r5HURuADGmjwVHym6Aj- zmXYY~3&GOUPJw!5AD-Qj{ru7bw;V{2mC&npf3YYaQoW3(mhNI47$fl_NN9~ zeLvit{g&0U2Y8+xR_iY{5=!9971?V?$RO! 
zBpUaA$tQ@%@OMvLdM?B$G7(7SvE^`a%7t-V@(E`jWEE9nh#=%CqKAS@Op|Vx2aSEk zp5*of;M>#8LSPN{18Z=>Nq`45ltQ*c8LoYP3Yu}k85vu!{&##8+3ww*sbBVVQsz9KYqYXkRE=JIYzRx-CavJaNj_sqVq*w_YB24Ii-4eg;0DShop0i(xHhw7E zj_k8{*)Q$>_Cko$ja!x4Yw|AU>7Xbt;P(t#?ZFm7n?!em!pM4Gh^(%MEz4 z3cOPod7qkkMd^@R$XcFOz6V89Ze4{H>pK3u=sJ4=3Dz@RM7M395N0Vg^U(kVKj4LG z+sAa5c6g28{~UZ7i*A}e_|H_-MBKpPZ&0n<2#twSr3&22KMH!U?p6{ekZ?!&`pP+G zMw~w+i@R&%2M5JYR>!k8+a42fE_J}?dnVnHA29JJb`Ix(j$;PUM~HYh4@7e)Kjf$& zoVsLGjTx|o=Obl z@ozDbmt*+497kf@#09%Z(q+56i^LT0&pbVzIq<~n~doNd^D`r#?=4fVYk zi6$^j!??U3az1-cWDWk9m_r%UJ=`f?q8<~Scp!f1eMbfz5YavDiKT+vYg2=h?ucA$ zK#5XkLqY~c2re$-ymG4uqBArj3A37Q(^c*q)a?R?=?TsAKqz7>8kB+xF8WMi{bOI| zvp8XY-N&iP?&V($;x8VScOzd=H)$5-CoMxcZ#oJ+0eM+)+dBMxP7jrxr#xk#)-!SN zb8oWf4I-T}jAJ2-GX64%A#_g$>swcwv4H;dwYZjr3rHQDZI|{3%?JE>$u@qn^^K34 zm@fPzf&YV7rq?7!nv(lfP{{@Ad3Le59dPaQLJp2p#V+!G01Z2i;okgFqsp|A7ED$F zWGaO_$cZ(sDOzS-&@DO!=q03CaE(k`O7Qu`ThGdb+|euQP;JIL2sa8YZHwl^eUHt( zVdc(idT6EI)60YT^2jo(+KHsjCBBFo4Vh?U>7yK|=|I0AGEFvurWaaWw9-r@k!J>L z4J|XfIMY6819ozTD`s+_%JazBHNGu@LkOrl&rN2^dcRU$-2NCdUpvo45atP*WO9P7 zvKc)~d#`jKfgpXco1wkNem2k97&~_x0tF<$<_B9?`;kBlCnVfTD+JKR=$Q3~0w9C; z0zkb)6K-_boDwt&`69!R@J49J0c=U2l!@;wNPaYA&5HMObkFa`ul$duiN%1BssNR= z_h^#7%-uy!zhCBy0n|0b5@IGg0p-NbLaj!EO0X2vCYrDtv22sG4ph4d!g0`43K2|T zSP;;YnRei5jBj_3G9u=y0GU|x(<*EZ(T<&Xg{;M;AO^%$LC@x!G{F{7z&x%n7!LfH zA$gkxkU&+*%q~>arm5!l0Y?WAd_`l{NG`=|%JUcn;GzPI)LE)^>?{|>>qH}YIiKd# zZO;C^{D^&*cA>=JEJ)%UjLm{ZF4zEv`Sjn&%ylIdf7BR^P*9)* z>L0ZYy4ouW*nJKl*vS+#pl2P&?I8Xn`qr+G&NjCG!8QNhw{3l0R;K!|=5|d?#y@xB z1z=`Xwcf4OYQ?dxE7rPc@3{qx<@YuL6Td2%=$XaV9qY@s{wNgrSupa)>xRlGkjeBZ zs6Z0gy=g0)pGaJlZX8@DhjyV)KU#S4+P+6O-shpBe;nN8ztd2d7oI@r($`bX(EIT7 zew9Vqb(1Dtvls z7Nc`DtQj{?q_YXeNMCn@?P0g>Pm#YMxyYd2+R zfw=CGqfv52!|BK&MF&umOS_&y=hfhjn*r}9*dE_N*{4Nwz8|Pt5R#&=?d$26poj2> z20ub&$PCF}#_he-%r27a%Pq6oX3=ptWb^26gO*g`J6{~Zj z(a&D-zqM`Fr}stZzr+1

|#Msb0x!SqN+V<`NWP&)cpK(<70esG_2uVnhgYmJ&?B z`<&W^D5bTZzV*`5@x*Hpj+f(ZHk=RWydZTt59sQPzw$U1))l4d`O*E*UZzZhFgqLh zYVe|;MZxo3#l2pN(*Hc(9{p!7Fo@DYv;$0MZDBH~E!hpM)gKXiI-G-)P?3S}QAS<# z3Eu&6!kozK3N>oSKq-BHs~AI@X{|4EU=?-)==F83%zvd`8tE$6j6GZv(kScd$zF37RVrM(nh+Xafd#nGQ!2cDC|Ec>bIW9)@d8CD*y|kU3x|V!f z?jOC7_2`?;WVqGhpdn7=TJ@v~2SeXa$NOhXgAK0?uc@JL0V(R)-&k=aptQr_&g%oX zKuW#-%R6~;kKepvV3uWk^04&GdSBUsF^^W)0;sBA5o=U4^+l~HaG(K*!v01w#j3$B z{s$kUxjkDu?lrl1t7=+Y-h112+xB?n^6<{GZcGqlCBD3eC{8AT+Nh=rcW<2|xy9n4 z(D5OyZ=LX!W=ivH@O`H0(Kmk8yX&@ZLQi%vK7uhnKcme#xG)*ZhaIxd1JvM~ezX7s z3L~60hU!4B+@*#COgi)lf+B@ntQbT4L`q+#$Ti=H?3FR&dvQQnWC_#9ISgp1PI-By zkv`V~1bYrt-%IAkD~fiM?S$%NL05@av1F-<2COCL*PsL?a2mbj*fY>74mEl*tPF%n zcWQ`8?%JLYp+MKEIST6 zQGz8OupV+z12p=Y=GeC`PqFx;ChsXrq|PmN?-ekzJ3G*Z@t8JXc3;pa(2rI2pPDGl|-7f{22v)1?uxyAu4 zp%Q!KSB(9>U$2RQ9<1ehwJ}NnLXn|5Bgy=ZW=eH6ksespYR=CGmWkVA}YYwkj`(V28!SB{EgwQA0OFNM_@p6%PDTOYVC!JGEDWL$=Ge%3{wo z5T>KvNrgw|`p^gq!5&O`j&i^7=#yXo;-EdEKyfRf?$jPb{O|`0tH?K+qRhx-Dd-c!9bD`Ywm2 zELZ%U=T1NQsL@Ps|A8YX8cs|UY`5801JJJ{xVHOOWQt6HAG9qd8hgWUOXnShyalpe zq94UC6Y7V)T#@-zXx-@ho&5U4On{qU9Q}2T%HF@TOwh^wtz{=`A7HuxYOk$P%!jjW zHvZ!v-`|v6<-QN7w_QUGeRMO(0o@pb=11UbH2yRqXos^-<*nh-xf#b?94_wvy7RQy z-ya)8UnjJEMYC1iZ*ps{5BqxvED(M%^GVIQmWGHzX`n~bw86T2cV^H06RbKD$PL@; zcH{$uo*KPShtQ3X(=%{X_8j7$PiGwK4&@eN1!khZQ{_}7pXU@*NCezmdO>yA09VIt z|I6zkuC$(L&M;uG+{#oBdU9J+=ckOd*BggQP#Pv+}`Jp6Qqh(Qs%%=zk zGEK-d{75!=-n;v~nB_KE2#COZb7@-P8VqHGd4L_#%_3`nzFjRsrn)0l(w?*$Tx+XAn&XR zt;~XM1e~=V-qmO#$J&)R=-x8mnvMK^eN+9{+fxFEW>Mf&QP4kmw2!<8^A;$c!RkR; zA$Cq`E~>KFcBe>HAA`EneU2zGL07H!c%2b`b+@IJF2o+{6OYcfn_O*TyH~C>-m=sj zQ2zp~ESeG^jb2=t74Is?^#>2~O^jNQxUp{JO$)HE%3 z<+lj3xy;xwF3?@5xG0+xj$gk0Txib7|B`>Lj_mOpPe?r`87t>iQ!6sYRCPaS`W=#P z3CKY>Na1{5&&f5@$W*wz#9Q2M)1>BBVA6@=pWw^;q+=I8uL1?3u4!C_rFiFRlNWtU zqWTvw0YW5En*)!C7rXXg?1FdyK$M0ntVB;{q?e$mu?ZuWykEM!Ux|5$y;u)U?NOy1 zr-YJr-vx2oFFM#_$P3DL~bf7X!Gd2)WJQI{QGJg3=Q~0f*QPzkJ~YPSnw@Q zfw*I|M5rCeH9a*;y$g&1VmW#iVs$@T+Ckel9-3sEI5y=|3P}unxUrr{hXfp zD4&d1q9IE_#Mz!|E)(s-rDw5o;vwWA(PzI1WhX++faRZ5c&)`!c5$$73iMCQa5pHb z&LWIoJaqPrLPd-hu0w5QqKPtKMSP1bxu3^5BBxFfE~wk*ahW+DG567#tcU$2bH>M_ z!-{0BTI|Au5Geq4xz{xTGPuS!UA5wG4m4*N`nJW1Jtw%wE~X`|Fg?bO9e63w;gyEE%zQ^DiPK}*&R3cO|m^Ka{NbiClX7dix1((dcL zpY_jrUY5kb>q>2~*1?-g!JjB4ssE0F=9xIeez zCs`SfnA-@EeCzaIGk++|iHO>_2UdJu7Q#$y5){C$hht~1JgnbvdU}1SrK!VngTb;5+;W;Qg3t8FkCEqFf z^w<3l(!FKGyeWidxzy7hMuLjIfskKsz2wO2ddvM9CQ2k;pHwnYL@?HZoqlh=r097L zwovrn-GRQeRpggM8^s1px|@hJJCED(^{It)6J<{@OgR3>ljcNRYdIxVN2j@QS+kMA zI0PhORe{fO@HaGFV-}}~q26upXCDx%0z!1Wjv`}`4OeOjz4Jif1LdHd$MPX^WvFs5ISD`dwC-8RT(Uc9HUG` z=f$qu6t#^@1lnyn*3T-peSv3weeb3?R@^#^Pw$r;rgAW-{s*7Bx6hvY{8{`l7H;~a z_k@O!&1Vi+pX|6it}S9-yb8T6hGYftT>>WWf0#k$mY|Smf{RK3HEAWCCNHQyypvBy>d=4la0|~(1LyzgR*+phM0s?E z%hHYK2X95&d?Slxk%LuHug!H?KphIAwoVz_`kg2}o^i9sa@+DsHE@oz6eK|)V(dWs z0D|VY=m~e3wMY!OGStnEI=upZERZiIC5hfmE7W9>KQA{`{_<>=X2^o|H>zk+rm8l$ zJrU?54a&s1RTziI`D51IDWi_`@+1nc2fuQxNdvANb{$T*Nh7dCxw}dn-Z}22mymd-R^X)4+*&fR^{1lJu$Pn2@!c zI48)7%faa$V`BtVzAgg#7nqPyFoB_>QQdjZo@K8H>A@^1f#Shm;bbOs3W&!l%pWiLB};$(VS|T==C{64J0cQDZh4_C`1tCgUVcZ>0wMp=eij?+*MVF}1?pYHg>{|C zyrf2dZ)U&N(EjJYfe?ycp7x3<&?UN^Z+~TVnwVCb{%6ygG4K|5<5|&xO_C(PkC!YB zoc$s#&9b8|YfKBbwd_9FdboTx$ile7Rm=Uy4TjIKkML{l9aF@#wk+8qiyLeRjqR_o z8Ll-SowK2b2>8nj3^ouR(c*Ir+8(6;k*Vp4&%F~2h5c~u`au@CWhqg_WQOhOVh^NF z&b&hV!!fo40!Gi#5{Irwp1f|KdQNUy*XHv+QqZ6%s8|^*a^zuoI9{}fb&J09!dl38 zj|qM-%-k|@`|$q2t%(t}dIv71KCImS!<{pac8sDCh7(YE7sU5rY%#AOv>UNl66M{u zXDX1G;fZFk%=fPB6|pwF@3B?mp7_zl!<`JgLSX?2Dth{R#B;W}F!-h>rigKYL!{UJ zqQRmUXZsela^lBP9nl_@ows*=*EW@+)Z5%17LLK+*)Lh^KglC}s%tT86AO!BRm@2o 
z#XOlLiSNF*`>w_Dk52qLm!B9uZy6YNugvr9;NB}C>z^JbfBu1LHO#&thke@u_Axo~ z#r?%?L|cY2XFDr4si9lEJ=cG5!j zAY@4tdziVm3 znv>Yo)m!CoqjAx2LOv=?DpWB1|KaY>qoMxa@NrzJR6-?tOl2!Vl9X+VkhV$obzY^c zA!J`h3$jkg5-|zMI%Ho*_9go=vK#w8V;y5=`aZl?dT*cKIlps0-}C*WBgTy9@_62l z`@XLGzAmvNetdVn-q5Ys&wjHXU!GuOI089B3>MMXBy0q`N99o*W!!8{pHiQ4dIY6} zj1=hJN{P7Jd2>b8iz7-X=uILih=%3YX*NA*zbB~2<1K-i_ZdJ)j*^e15<_LsC+Izx zS{TxAHB-Yl7wq@(M~1X?*1EOP2Nd`T9?=?*+o5H;MRo5xzd^LWjMP0>Da4~*{gLHR1bhp7Hh5{sRIa;G3dCl z?@t;A7$u98HcCeh145NW4FE7QDune(vJ+u<-!g^UToW9qCQ>y^@aX$Jqxb_z?#%G@ zHxW|(f^mK&8$rE+{*%tsc|g|*S_+I zZ^N2t>CPrOj4?&?n|g4gr1s!hl)mS8#PaP=pyR0oIVlXX;$S;EVocgu_T_4NuruBV zGX*5eb^GqM5`=G)X|g%Q>PFG~($-u)gmdtPD%nTnYQ!RYxvA>AMz~2xPGjxShC=y{ ztdFx-Y2)3Pg*lblBYg) zI@!xzH4Wm!$shDNR8S+=kA>P6P+~s5z{{It|LjO;h9-1^a&L+K;}F}Cjx9ujGy@wI zsM%4=9)h%vz6!bw3K|+}-WaER3hp$H`&p)6VT_rbd?$NF{Zz8#5Yb7SAWcbVYLlV> zEG$V`9NUMvTjg#GrK>|6Mu@eo-C6;F(Az2m)*d`rJ(0VgmqL|p_uG`zx-I(2l~#xa zGIQ~L&GRw&hN^AaqA%?`mMZrc0A$cGt)<;>EM0JbANPp9PVuc>omvnlN#)&>O8#D& zX6_yY;YQT6seojh03J-uxsLj*>SfH}HJ!0Gd$=4Ct3lv0Bi#(U(YvpGgz{Me6*FTL z@0RZTh*Ou|a3&c3Rgq+~u?02MjOvJpX?23aGOq1L zR0M3(Es&=?DkWab^Jb|5)c&4Iss(155k=`s`4eVonRTu64|@HrR4&(QZoRN|H8q+Q zjq}c~bVfK2Z6AMUUoN7?$VY8kH$OR1?SF*4=O_7kAj>>OAiMubBYuBqC+;HoB#rpN z(&cVR@{PNqiwCoS_bZ+g$lN+C^p;NYVsCn}Sf9ZfN%$@G2LVo@$81VL z0M`9a7|smynwvxY4cEl)cqtun!jk8lJWm&`vzyt4d|2Pu70Ov!&n>Fb9Mi@l5ZW1@ zJD%+AAEUY?FSATaMPvM0UOeKu)))Wi+m?nYRIsBUmttI_&`49xQ{Q`a%fJ(L3pCur z;|J9#L-Og1)0v7p_?PSS-?yRTyQI%b(GR-G_Z$MWtY|G6q97K5Q&VyTL$TclS~N|5 zuqKlY-w%u)01x(;khnwlc${E2`GIM>lQY6v>{A`dz*LYfIJ^h(6qsat50hh%WZ{pa z(*OC+d}IVzBsAY$8&LmC)Vp<{t>~Y(0&S~*;WIvYnAW&)QO^z4QH{1Q&jgt{jhN&v z{s=sNSf*F&g>{62JQ~+#minC^v}0ZTf#~hN#=8z+4|aDmZrzSItLNTsez0WN#h9E- z@dA`c`(%2_*OjTo)fyUy+`lx{|IOEb>BZOnmv;TXox4D%>VIM@RuH>@o$|Nk9A67* zsV;!XkoXY^*oivmZgN7vg7C?l#C!Ien~LB7<+Cl>+rx6}MgsK30*A7K#WCi2*%Xx( z$z|&tAfUfMPc20OGnr&`6)^*m*Gd*%QCUO!BB{(Nh+ue_q&yeMd2ZkfDBe#ItFGFh z_Yt)W>3|M^_n2d2TldcJoFQ6zSf_rCIxldAeZP0p6~V5RSF4#*>#>EF8R|C*pY3s; z8hJkZGBq4(8u93y+!EVHZ{S@UtGJQ0q+}7L4MUz`Jm<85`_bi>3tLkC2`#SUIbbW{ zsT(y40Py&!1!ihlJo-5C(sd&D1-hheEHwaa+KCxw0)9D(kb3PI7dicrvvph9(8IW= z(!n-<(Q{X+CK%9zN=6hy-)9W|!P};b9R-~R#Gf?J5EA;ufB1$NDe>Ps`K^Y@ z9?jG(g!CrYH~2E#0Y#c|BhSq+aXL3$!&?T#dOE(8@iROMERTa>g9plu@%*b&P;PXl z74kjHK;+XPYSX!Ah@-+di9Wou|52Ex3I=76e0XoqS_|t!MLp!nJqv$nziZ%%g z(;G$j1iK>$iFU3ivaeNn_V!NLqs^asRI=%j8~g8L0v2ZO57vaYrv$EY>)$WHWK(KA zRoSkx&vKE4WdK|rg17^cmoLy&UQKoCHsh8V0@C6%y$t}fnNmEigP6qG#2|HUa9ujP z>)p98b(U0L1Ta;x9rHHsorO8MNambUV|}Kg{>r~M69)F8*y@R!APx?>*rs~GV_ zzY`d~;SgRwJ0;zmwR&j#^XAgt{v+H8IBWiiZARhYzb6!Ecj7z$j23juvjRt$vDv>o zU;5JjBYwUaGIg1C&!(Aij*zO*A$vbme zy{^=@PuH?J68uq=p@&3OKF+-zbQ z`7ZfkSmX`G1rSpZ<0xko^_vrhbb6XryRt9080h;x4{|iCeqf>HxN$z9G3_pA5v}Zq zNr7b`(H4wStyKP`Sw9SL>-Iy2vHx$Vt9S32mOHj=?G_*hwZ>o#4=CRoJ#}kYFj#pR zx6)qP`T=p29bW4t{}G!mk|)PXs$Gz(kv-1K=xce7j_%8qC3pm!`MHQxm623euHhL8 zmZXBP?Axro?q{dM!n&d<=gWFzHJ(n zm1b|H415uvI(b1(1$rxa>-wOD_T^~7b#LF^%F-=Z^jGWpJCva)bYyaPi;uQeBC7Ti zb_g}=vG$Z2AqK5#{~SA09OATOlDma`|NV7!)|j`6&DQiDitzmK%ML$z)zH;<3EwBzZxa}4jjPbCIy z3BqV^)a#omdv-#gr`FNYLdBl5xH%=B#cjo{^HrCyZeP>o&;MYSP`TVMIWV`Am zcnE*l^-fRP^HoF?-!JG{I{O+h^^G-;m*!>RpbYno)eVxga!?|rEv^~{OV<8Vih#KM zv$f*Ru9?K~H4%9Al+K&{S9ve)Fy*Pt&G`dqGh)*7_j{cY&$ez1P%_2I@LRRpgOkau zWsRP`8SMrdB(S$zjk=(!FhWOAtK0D2dvhF=c}Q_eX>>S}&)x}QQ&}*non1FSEGDH` z_`fQ*PmJX|MR9f(87vJ$po#-CICb}oN68loVrUm^S6VaA()P^7bSKBj-RkBI3D>v2 z#z(v2l#1G2tW4%vfpdG1C-TvX*P5;`;p%32xDWqUvdMlyIX8*!v;Vc?EV$BA4oY=^ zS^hZ)l=@4#`Hi5QjSCxHxF+SB|BC)GB~t>~;U%-M)97v_x6HUk@op|{)eoLK?vzaE zBGRI4M@juZYP$am74LS}*ncVq_Y!wC=mCYDG-6`wr$sz<0Iq~wEo`~R#4Vh@W3~Pt 
zQ_1v+Kz(+m3&m;qT-A-z%bXjSsCI8VF(S!CA5ffSY_fMPJjKjCj+cHwIJKa=nbRq` zAwG4_GfeJeO9)9+|HAJ-Wb`uftx^pW!#4|?OL8L^UEJE+W4OaXnH#aD_ zxEC586j0Cib>8lD<&;+&pCGx^aU5YVV z|27Ob+Yv7JD`?1doNeCHdHs;vqCw{0u|OI=$W@e|0||* z-s?G<15Yq+d)!?-akO{$c0m5e4!s0A4xEMTNfYWRTbzofleH!4McvJBnZ~5|PCseROV<(N{ZK@QrlN~yN+pK< zRG)HpiG~;s+D4&TnSv6qy__h&!g-6}U#A*?s`5C4=s+;KYUU~Pgo&e@vWe{~5QBoiq zRrI(G)`-!(Cr$9CW(xD@nbS$xMX4Ut{O|~^%{d2A26#>O2$mNhE`J=W z2lz;h7IOSXe~dq@7~d|Qy0yUnUhjNEMy3FA8s#Dj8qVZiGsq)mS-S;zEaAr$Av59-DxoGOV>eBmSlM=y|ByH-hL-64H>j zj$3b=)yV~%hR^w!M8y^E3#;hj?7F+>?A>K;knF)~7P=7wMiPQK2FjTA=bY~uSAMMx z+>vQFAJ*!%m!@px&(*fOTW@jELmvY^8;Qa~x+RnkhOVfdgt_P-0nRs^UB2K_CQkN{ zN%;^GrYRp?4|Hd5;DV zrqU5z&r*bGt|_g1AO$vPTl56PT?F$DF z6Kw;?500=jlLwz9K3rq@fk=^tt%zaftCL2jNx5{XsQfOQT;e4L#&^=-Md|=6@04R)77g^<|LYJ8mFd{Gh9 zr$a1f(y28zL{#x7Cd^I?te4C{FV&cE3E-$-0AwE@Pz65&^s!(6qfsCRgnU;-_TU3x zyaxcdbV%Uk_cs1~dB7<6?|!)mn}587*8xUQv8D=Jx5x(>E9P=v{iw8RAGrZ?@s<5u z3-khfekRzW?K0LFh1GvOnZpNN0is=8aIa<^>)yGxHZ96J%Bv3ohAcM!2nNTE8n+$S0~5Za`476 zENXDc*n4S^9QxER=P^o<_OW7<*~%Krcm+Um`s-(<#+}JzI$A4s5cOfyJp2|lJ1fOI zzB~W+(nWy39Q^vaEV;KhPrRD9qiuV6&y~DDAF#~dY}>eT2bua_Gx>Avu8RhOnp9># z@yEY3mcvY%DoVkm--SL2`|ao@b$;b8tx@_g@#u7yD4x0dG<&#S%478PtumMDuaRUj ziWTl@g9~dG^*i{3X?E3SfP4022YJ%703hfhoA^~EU=6X`QUfJ%b2(MNYxB%T%&Jmc zf=uFE-I&kh1$ulI6_%r0+w@$WZ(F9og0X{iY+YNzLdU4kD5InTbZ0L6xC_{gZk>RZ z-XfPkz>nx>J8B)=Qe#@U>v1h8T=u@8-4^seorqKvix{Lf&_Ssi9as8G{cV zI5Wn6FJj|)_cm$_)Ps~*P7%S^ki>gc>L1Tu)wS+=i^h*-_SKBWRGd`tPRzg^usL|? zBIM3CXagd;aA$L#tJFx!sKTJm*@0Do|MBA8J8uJql2)bF(~XK!OC&3L&DS4vvmjM~ z3H@JA2>Bv~onwX5$wyG7aI5?EoqlJ7l+HFIp;k34w7@F<@wXz|qURj};%~x;cLw#b z&$sz?r_=`rn7tcyxN)dWo%E$!F3v)x~G2~T0=HTiBK(4pWB zIIk}X7&NQ@l0idx*QYseJpRC`G@NZLMPw4NP~>HbruOzP&N%;Aab9*pqSpx8Q72xn z**|VXl6ZOn>7p}8^Mmv2HlPy>2=8Zz4EAB#^bqu|mtQuusIH5C`}XL06fI8W`Ws(j z4l0TYCm6x8-hni`rziF0L^>i(CTfwelxINw>O>t=8RKA&F3_*`e(oN}ZP?Lf3*O?9 zSivA_(*tt&vsJA3f$$Q@tIJ5*MD3h5&94vjRI=_Jl__?O)rileQHbL-HK0G>4Ky6p z?yP}xzB6}rS%A6Ypy)R4G^zBG5#3R;{%(4HuIU~VrIfIR?7R*bWPK>U5^Y7p6I)k9 zlAp;D(?oD53sr#)QW_IE@W)9lBrZDUJ$`za_pIjL-5Dh!;2SJTc#%Y%tp6FWrh7iS zkKla88(MZ6iU~uUr?Eq3)M_Ht)O6%$Zj{n}#~cZTRUzp?sbI}fsM^J&Gm6t(Du;;X zqvo^aTE;qa7Q)f}zOc9R1AgEyULndXC%rGoeSGXffTnr#1E#*QyVXir5V%gGIyBUI zTAn-r@lg?tnScip4K1=A=Yeh^`Ds1I4#D8-vPCOFlxgP{1HzuTae8ZANMeZsWUv>#S;nR8ke8|r2jWrT`40)Bkmr@x7YTPzqU~-lNV0rm z^!6>@L9&un=A3n{4(~4Nh&y2I_K`R!_o2mW3(+5DCrd%&jg|=Qob^a^?0No_k&&gO z;l1ET{I);(=?3@vGK6aOqHj5%=(uwca6X8^xaq`+M1_PO#p74|ww?h(6dbL(4go~nC~s5lzpGtCuS)Y!0nd#Ng>lQ1*%<9bxF?P^0@ zj*Ym_Kyjx;IP;XlCc(!K4fWcVEwOnGYjofmkf3i%&o4IW_@TJBe-P|#fh>In4%?FM zQW{;fxw|5tJ!N0ZcrrTBC0wj8ei8yrvTcqM%oroqn{kT z++y^}w+V0Hnh+#0EDr*i9Z85!(%=@~dE5~`5!HHIb5w8Q>V9M1H4(-U$|EtCs7wy` zsn}CN3W$Ql!L{7It+)I2#_)UcXM^}_1KsSs?5gZEuAD#8veJ}`50>mNW0^_ujR423 z38lU`6Ip_d?>Xvw;+;}o7rPiKKwl01`stmSH97{!DVprht9j4cA!8!goz{st=x|L) zDr~%MnL$%2)cT2o{)qHPBvFdVSz21$V&#!SJgvi%R^_#?EY%!IJ|%_LlKAYWMfrUe zwQ##L*6S|08l#if^x`VbBAqe!3Wn$BQ}4WrB)h@a1r2$cuK1)30BFy20S&q$hKS^GQiw!s zelX0>sejg>nQmdg;VW~nD9-kqZ#Bx(&d!Uj?3sxoY9ZqQ#)(n~FwPg}B%G*AtO z$gg;$gHp6VlYjI2{^wkuPYp*^+!597o(sVQI#t?q1dj=Ps6#z11sFKV>-xf>TWRB+ ziSMZ`2VyI&7yPzY^}>FSKm72Z))mU-V}o@W*+vaPUn z%o+cm)tw8E)%emQSI}B8`jxO{cZr!FEa4?Q%fAm$7+C)Vv&wG!Bg;c7*fHo*E z7~ZlprN>ogV7IZsKM_ft-LQZ*v?sh;mBxR8Oflp%H_$@Sfyox>u_I>oyvxc@=>gqh z<@e96%#NLF1J*DATp$zJtrw++hacyfH~6pXvE9*mle+}2RoZUgn;oYT*AyS)Eq+Gu zj)0(v`Y4f>Ja;sBfU8|@Pe18+n#6}=Ot<41dQj~U2nNn*hUc2Sn2GMr!ZG2!Y zXlkBF{ra40SiwS%V+7IuGTjwl!cG5hruDZ=rO+hkx(EF)8)U?l|Q@bLyI1#r@t5v2?Qk0^t%ZJ4}#3#A6H<D*PuEo^`x@K^lYhk*yr 
zW#TJwB5I>Ff^hCvSx_E*qwMH3!>;+<-zUe267S!MdJ0=R-E#n}b3;mFTF->A`s(-6p^wbwK+J{>Q;#d9IXDKJ%+!x~QObI@G^XWmS@0r5ThNj+k_||)_$gvN8Zs~yU zEFr7PgYE4097UQ5qSXdo)7O5Ne3t^LYoQZ^&?Ui0?9o-*UXj5PDWA?wGsb79RS%Gx zKJGdJRuIe9-f2mT8^J0z>nDV*AcGWN2(0CIY$nIuFL+89 zJ-I3Q4pr)6aB|)dsxWmLfO?Q_t%vg0`sSI%#At@SA1*s&Wo5v)Lx1jS{UMXFq9&6}2()-|yp6SzrDeHyv`;pm7tY8RM+59jlPix^1zW6kvr(7ic2 zQrvR0w|CAgF&l4NjgzV2UE;OsRTQA`L3Q?yk2!~vzLBf2E1z5=lxT7_?{EXc$_vGF zn0Y4?Xv>|TiKr#K7Sq$4Z`I~K>W-~mK5Ka&fb-+_PWLE#PA~TqJY8En|q< zrPL`|Z<^PYn3E3CwOKb=6q9_X@Tmv}ELIi!VR)oDPy~cKdc5jbcw3f2_G3M~Bd0S{%1GdSH0l!f@X`^D<8Ye%_cpl28O=IzKx+~J_&vx(dvA;?Ze`K^K zB^T&?Y0UKoOeJrD60D5Q-W?sTOA8wRiD8*SxbtzZ&LgAw`Eo0wg?!ANSUm}*@E;mb z0ALm4fgZg2vW9o1Uew-|oGAO~LxP1pcja2$zSkGPM7%Q3$c=LHuPf~dRsf);C`3Dp zg13TZUf>$^5Cgu9SB!*=z)SmjL2Xj2>|92VoIx7=0#|(NwWmnmw1-3W>??zNpu2OV zQhG&ZzDhmqFKRS@?hz42s!2N7R@ZSBSZu2Q&BkI? z`X-y|l7Ou44@t=y9U@l$J%8fjMr8UZIAppq4B)!n_!FZ29}!*sy=u996(FjdlunoBgJW%XLd8z@t8aDvJ zrTmv00UO}KgJS^P2$l-gheXyH)|S_av1IjLnbXM#tlc^PP^#S?(oLLShN8mc8R_JXNTOYVC+S~jC72dmbTmb&-F%^W= ze>9U2eTs4v)F`u|XZM?YGL;hPB6?`w{7oMk!t&Tm{BU9|a&%7xT;wf?z`C3W-eor5 zL%K-t%cxg4$g*Ie*_H{ow36)k*=ASwd+x;`&a+<9q0n-LwZ2 zu{N|;l{^5Jxp=Z%Yed*8=GZP-S>{~D2jt44#vnCxjy!@|s~cL;+MMeV1|2w0>okf9 z{`+x&S`AKJFd_T|T1jY{gaNG%UON{<~1t zyNI}6H_xTwj|?O0laWNzfLQ(NfSV7Dm-g!5qD`B6E3EXPpA(O;a;3pOM=X9s?SD?x!ee!X{7jf8L!5K5lE(> zzP9T*HPjZL>Of`f4`q?9ef#invoE1u{XNDyt5pk5-u5e9(S%GIVZ3v#k!|ahe8HHr zAi+hqzNDg=TeU$qzis4HSq}l|I*|DhKnwi=3;-@X@<%K$)4m{xk?c7H-<9^d5eYEM zJ<0ubwOh**^jsuWa;Fd9{d70I7m?U~S+Gfb33jhk$EFy01F@YNdwSDBL3U{^YapFk zGfTjCwTsugps)WoA7xM+%^qxc_6+Oz)Q#Wt39nei@rllrx}U8d@b(Z5*@Q|M?$LV1 zrTk)%>+_-GA{QKMC1hA}O?D~T?A9BS4fcSskW=u~g{t@SLBofxoWCiq|BeJbet=sA zd~oV87Ya$OE%CHeP_@rDS1qYt*)pa(Mn(SOJL=J6GE1m8%q}X2o*scM5$!R7pG!Ae z6**mZs7tr7$+e8n{8%nE+VGTN=bfAD%9 z8J!yo7_28^vb}B5R3Z$ogl=9-^7MTzA+Fu-)m=2%`C)=fHi71QpHHF0mAv-@wA725 zF_9m?RJZo-%e3AM3u8`Ku<@bDnvhE9PVwRTaIfd>#DK=-0PGZquVWtp=;!5pk!Pj= z#5xdn4A~>h8Z>dwoK&+f-FIZa?(@N$ zZy0R0ck_gA?lX>Yt1v4>Jo9Rp@Y4r%2^{1eX1~P~dT;BuXkUUoXKwTN)gyK?k~8b2 z=&q%w6?K<;I1p!0r!vM&-yx)AK>S0AEeY2F^>*$@#960ueFX}Z%Z zdP9t6U9<_$wmpj{KmivJqclLoAU^m~&Zi{ESzTF?&?WypGL{(2PI#4Ov#x>>#@WU5 z#qDQO1A_Z3ZwS#u1pBpt%FoFc>ysJ~e%to3eoFC~`v^c)v0{ObdtJ|G<+$S>x9bf@ zt2dp|u-rw$p{2ymdKwu=*Q+hR=qWP64^k|Wux(UZ=G%P&v+%@RD;gzv4 zqvo%@PIXnpq3jQAbC#|pjM$6QBu9b$n|eK3XaGai*LE#h-Qy4$swuq90@50hq3Tu~ zC4&`tY;<@(@>KsS%TybUv-ujm&dUCo1hG%I3gbQ3-^ZO|+w3je`zQ;a9DJBzm-{N0 z79vHmv8*4Jqw_X;{_!Q{R8zT`0R`Dhc&lX9{(euz0e~X4e``RZXO}LknN&#kaJk(!&*BTPFQoI_K*$u@Xio&ryJNa z^F20pYA}JJ&KX;c3?XN1R)u7BpNZc_yBj@;)PCplD@0`G%i6HuI4QM0ZxN5o4D5?u zultdHMfiB#kF~EaGVkQ~vfot%Iz}kW+PW$!ZIExzQTb`#{>=K5j~0aDgNE8%F7NU^ zSetOB@SRtMhI2I@%uljvYKn@HKLAAbR{Eb9m-j!$@#IWuY)U&LoU=a3{($CC=Rj1d zZpTx~AYqO<%$%p;dz~!!spz9r(r$}<$rR$Exa3*-W$3!R^F&gfRBcxEV)Ieq< zQcxhHHohPJI+@}&L5ebE^PnAA8Zw!78$Eu6{m>p*<@PgZQD_C=p8Cu*PgWURyEQTG z-FcW3T;ud$75w9mNdv&={?ka{ZRBN8h5E&N0EW|^Cir9G z@2mktIM^CQ5k7sO6+R!ue?e@ZL36B$tnj_3K3De0GCc0K}s43_3 zlMB18=?KIM*bk=@mD?EQ?Y=BWMqBbTYA-ul)mb@?@mdHK9bt0^58*LQqwrPSrQUKR zmD|CJZNfKa&xY3vN*L|&ida`+={(!({o<5uOy9p<_oe`96P(V$lh`7$ zYi`J0_@kP|8R#^dyC9?9goup!RruI9%Hn<7uiZkWgubTEvu~NmmVIT*F$a5sXyK)J zfJ8;y8SbTL|31HX?sSI!!5D|X@&0C&)wiOF(S8yFC0hPXuQ8ACZ=utqw)>pl$p&YR}m#Ff10t=Re$ zeqWYzu63n%@3W_1X}5Qlp?lzrf_C_~LpdX4JwePG&TC3s^l0$%-b#}XIq1eB#w$s* zuy%Ys)3NbTDPFnB&%c9uk}BQdJ1wE~NpZsB^#c4DyqrXwtVlMRSlXXVbBHJ%aAdK- z_rZ^)n}PHM6ztNxA*Km_&+9fuuyxd}teK(SBQ5Gi=?l&#NHF#Cn*oZXT!ry>${jL0 zYv5_IHY&HAHXUf0SPLni?xohU_}TVZHwk5yl?97QiN;|!TgtWXF5hp_523&pnr`~N zi=Dg346jF6kHKG7`tb!OSw;gn9}zd3qLtU;I)u!e7%%R&J5+@boE3nlH%>nxF$91_ 
zC*ouHjuIPLw~#?RHzBP6op0C~MWn~c0VDRJ1$0MRLNL_JCXXvI>-liztWbRjgC_MY zb{&&wPm<+&b+pM+_C|}rLZEua^vvcH-EJ}!wm0v1;r!c^U2#ej$OD?yb)ydp*% z{q*g}rbu^shtWt65Q7ixrR*@$OpwYl$=J`r-YRQT)6n4NSFui!$3c28qU!o)f0o76 z)O1UT0wc@&X10p>+x-@~#(tt*Lr$DWbY|!ToqP%DW2PY@)E@*yvVvH}<&6MV3{g;p zWd(iVa-fWSW`t5SVhX)V5R+G?~TNw01JCD zHemw#eJJx$NP~|T`TGmetVWrr!b=f6b0tIQGkpmL+A4W7Ef%`+#7TnH4d$KI1IitS zfw~K+?P(%E{B_YB-&^6*@d?v`*%mgAOi!Bh1RMKwl=YtxByd(q;Mr4#h^q0kNAeCA z0&c%NUG>x>GjlhZ`@V9K{x<$7^RaI?o(a079`|C4?$({>LM%l?i4_zZVs7j5Ls6Xp^x3N0k0o*|&!7M&x zP&cAxa&)x8CeTg?LMkB)n*Hykc_t>7Orp7>VtI756+WI!MI2wl`K=5Fvyfm-vYU&r z2UNTmh(XJ7SVJeGpI+o-9ri8ukJV)ytF^tV1a;KXZ#RyT$@*;uNi2I4H;_jT? z^73KAPgw6YNniIM)U8cRT?F3$c4L3;&pPx$X!+(kQFvDjYqS{A_fmz)cd)Ms(@0UI z`MD`}6yvCXg<(kB;-^o8(JqVu887|(cPPl4Lm12If^;?BTti*+2zbB>@^)*vXsrU% zH8MAJMmUpG;no?u1w`oLfmHgp?KNSW5P~*|wjZn!*V7sxZHRE7Rr0wjPy@K&@rVRJ zYPL-0;~1@t5F>sMiLp}xmI=LYB)c>=7~}2n$K8SZG@L13nbFry#9EOd3wQARR?&b$j6!tA z+j@S^zjw_}k4+#oLn(`t&91p)jlQ{4<(L;cy7;IBt<}CjQ?I@ZC$XH44Dj;#7)@^S z0vWVxDqkzaeIdC!d@Bm2&4mSVWLDcUUVG)arE}(Eqf>oKK912<6%Z*%Ulvj}R9x!u@xT&IYP~#@i47`vd0tXOfw>=s%oncnG)h z*9ba;e}f12cM6Ijcyrhh0O0JvDNnsMX8bQ;OBUMl_aSm^)PEm}T1}V(S0O9F(Zihq z8rNqc*PyM?LqRL#$pHxO&u3R6E&4akTjl1TI&{J`RAGi23pW0UUHgrAhnTm_4*N6B zGhx>4;m4O<#gYFt_S7K?0lHq$>`E(YaxH9&+S>ct>f!er%z*T zOC@M{(eS7HowRv*whD_B2lviN(j#NV^4JfHI=`tGtKOzxsE4ftmKE<7^T$ywxx{mX zG2tM|niP8OX=@Dm!$AgYe92a}Mw%MoJ#^j_D(70uXP0sI#8pt| z?C+DSMm8P8QUjT3D_h3`jOFzS0`SJZ;1x-&DLNY0i1mHIIPji>OoDXNK_g|-e@8V?nU{9fY$!t~+EbVB@^dB#F+BuRkqP3`vd01P87@+qd<-q{ z`N6xrM-E8XLA+*@>x2<+GlcixVCdw@crXa@&$7JL8R^kG~ary$58cOz5xR z;e_bgTG{}Abyi&;+{9sf-p**blzu)Jn2K4tVI{4X(GkFps%)ZgOd-jFcbTsYVrPuog--$b=(KfSS%b!H8$y7vK zd*Il((?sJtPNkQcxrL}rA1K`XZ2@2!o!hf@asoqBe6I(3g~yQ>L^A+Ca=gIYsThn(E-Q7uB?6Xr>=7|^X zD!J5d~#I^fgr93 zW}12ojoL#Of5u?d3q@_A8+gS{?%l0SCHS2QC3KOO#rF;N7_KH5K=F<9VmGs8wE3-< za&Eq6a@K+7JN8_BOu9(U!_$cq0`(`=+sI*WW?13Ek|IicZ~^Y2{3^k#?jIRlxjpv> zSyK7-%A}*Sy&jQk7zn@!4+}#>a^D@1_2?#wtZ|yg%}hWAMCRKCr@453qT}Pph)X&W zd~d!XAnf(=g>+V&itRB2YfL5JECbC1{9hWpy!nM+R1Jh@&Znb3vGaORru{q>Jz~;F z&X{c?*$<>85-vGguGY?@H1rvazK;ps%(G37Rv?6iJ5?t@86%^p zT_K02cv*jAU$Spg@uT}kdAHgZ{K0@_eLO2E=$R6z4>bY0f9&oA_f5N3`k{EpNWm!f zofo>x5w~3L+I;r_tr(I0Fp4?5m=WZZ?1zhjPt^B~5?s|klo7ar-&{DBnFsZw#(9_7 zic|FjFG(^YcN__uB(vk=Rs=vAW7&80r~622ZhSR%XEi-TIA5iUh6}$LE~hT5)_b)& z$I7fve{T(zj)>dtLpiFGz3JOsX8Ovn+>-KM!Q&dBb^kO~n73s8z9_5qYGKu5a-QYu zM{g!q(UeC;^Jjk$2vrzVSKfVL!y|F=(p>!)hf^nBN^29RJy4TBZFy?;@+_KP$#BW_ zNmDlg#u$zXZPj6+K@N-b(tVK^Ii|AZ2>dR!Z|AMdRb%yTjaRs6SkG%}^Mf$&r?%0S zh55LJZuXx2tf<8@$cGm3kX-Huia$08hfHHmMD^*(gtB^sGy=#jn8sp+%3YRTtE3Z6 z9++H)LyE0r>|9$U4aRsfQhK`D>E6dK226$o$dv=bCfDitvRn#6H_z7Q&PjuT+K1&9 z{r6OS_jk_JU&l3}UgXz34bLcG`V#U^CMad&JZYmKhgzdYPvephy~`NhARM&RjLc%Up;{PCMljTPvT6UzmcLki=8pkRmY_lxQV_uJ3 ze(?lgSr%cDvo}Ph59-s|@-$!EQD3I;ua`l`v^sjJV_hvp=r4sd&)x)%hZrnu4gVJF zcA#-l7L00Gi3|Es%A2~TT0r2=upN3X0KNm1G?{IrHLwgwyek6W%zG-4O+lfshJ-v6iVJ=y zgy#lIyyDUt5#ZSg7%=Xp0{bp5`CTPElY;8%0+tC=$t=Q%R)db3!HEU)oELzcO93C7 zRBwrEDCCz_AkOlnVXLv>tMQ6_LCU89&*YnzyH8um#z{f~%=+ghfK)bLQeA@0M;XAV zyLSa*lJmzhwEwkKi?!l;?$|$r9EQwk@;t2aT`GkBX zX%GXtgmGUv0*j?N+`-5HH^6!GoRV+j7s{)$+Ww@Bb1j}O_wIdph@Q;ad8n?&?+|;0 zFY?L3>4f3r$14g92fs`|n@LIYNzn<92qj;O)hZ9m4!sv;}h zve5^skjcgOjo1NIn3i~$$0f$%4+68|i^YiOX-M;h*yn6Sj%Q?o4BaAO8dkpoGun&- ze1yzLoz^v6Z8=<@@O@&AqIa<=mf=vIyES@gb zW5lian}ria=H(8&K~&7PQfG>j}9JY&<8@#ORxO zR`Z&vm(DvKf0}M5 zw~_5Z6zo*s35%R~qu;!iRTZCbLy$3LFKPz!F$O#FqtYI01?Krqk&HxN0BplGsFIPJ zdTUkEm@HdOiQ-sM*F#yQ_vP=L#Kq4L5j}?%nqhhJm(*h3Q8@d%4$_v(dCOd>`ERD2 zKHC=0(Eqtm-I*6dXa9<^{Z#o`1aG6%V@SyYp}PU|sE=6=0`wkIODw4w)1b;!D$BLR 
zG25$=tPtHt%8S*5m1^EnjsO9W26j3MSQ{3uYdiBS3jJ;-`Ywk+bgqdMI5uSSOmD0w zicu~#{F*I6pt9k~0J==-i+$n0TxJC>ES=N0KFt zz>6O|6aIt1n*pb|3~IHZxVx3cjE+x!JW^CN;r^!3=N&XPVy-SBp!%%XbvZDA#t-e_ z^lsTC=PFLRsZOl4w0OBjMzIEm>R$QG5bM4T4iwq*IY@wlo30r7mT}^7Dq%hkA$EZE z{N5`5Xk|`N9S#}W6DiQlu$jvAm+2&}RUi@a7IH_7s(;0U>PMuRiGmP`H`^O(71nN> z4Ja>SlCXnI$X4iGV<24&qoN-;O7${z?dUG^a1XX+!<7+I;eEA)Ksp>lcZ}*u=oj-ye7jf3xW}!xE07D0>tzv~sfL{wFDUjK1UzM;71zwHEo;>#>I(ENPgQF@ zYiLC22@yV@#F-|Lq0;n4Sl&ZMW(P{nH1LzBPcc;-&L>?(>R)}U+60Js264~8-$lAf6sRI}wI^;8~x`K85j_X5_|8SHSkz73{#=0;TkfyU-w8ew$=C?1{y)sUcUV*1`Yj5A zq97tjuPVKZ^cn@}0wPF$ zea?B#Irl#IKKDO`mBn0huFN^!cZ_$u2p?tH;-f}MZ3mtG{(TkxefCe$;l;E>BBeUj$ecH;&?7y)yQ^m|xNo zA=&NOA1HI_x!v9J9A4fyE(la0tqGYJOGZ9pqbSZJw6M^>VHfzJ<4vADBN+F6Y4;;W zHY(z|yHwnnj**wMqF-PM*q~vm^g(?SF*EIhS_YT?Hrfc)KACUza@asiZLorm2udLX zHP;R+-W3Jxt)aGCje@?6gJscuO&Rs2^Z z$$(yz+IZ>A2aXe6mKPB?j*7&Y&j3cFW$s`DSTC~;q<0%N?$(%$Y~HNqv`|pv&&Uj8 zw(eG^x!6!dgJC4KE>DBb0{i8|3X}FU!8_%Y3#;?M8Td&=PTV@%P$x+KCOVX9R@sJe zBPB+wK3BgE^9Z37ylTp3W+zs{B^7GS8*I`4eSY9LdYttnOR#mKb3)hVXp?VCJnK_B zQZ%-rV%*nnt;%pg=kXqS=ae&Bs0E9iJP&ZjDTlZUz_G zoKq*I9u#%h`P}-f#KQOKB=2-ynQ@= z=8Z@6mtu`hDB?Q(0sibN-*ZUf#s*zC%Og!)9FMbyX<8A|D3Gdr1^eLaz+pf!^65RZ z{&FJL2kI%;lykh5lCLB@fucX{D+VpA+pO%lkA+FI3W#Z(0_r<|rx7Sd59 z|FQvaV@w00?aWZ#37LQ7oBl+`|MT`pp}a%YOQAS?pwpdPNP(^8(fjt~{1W@ECb7_6 zI-~Wup8KEFzCHb(-$o?y&|j81ba<@VUeI)$-z}BmkoSt@*OY<4eYY5HbEh~9H?~Kt z88_0WtFiamXJx0q%BVu_B<33ACCUzVd|i!j39l5x>$^|(@(JU}jb^GXda^DkJ}OMF zcBI%bw`M}{)$A0*6>o{`VZ2cNU|D0RpbSq$PyfQFzO|$=B&ZW~6D5C4Ds0#kUCogH zv)s_gSj*QF$vrzpVrQ`NBc1d5xjj+%roXfi?+*kl93!-5Ckbm3%4gp->VIALT*Srr zC!RSaqL*4F8!wkI}!E%2w?O0~Ep^HjH?G_hDrD3e**J zf%ejIO4)y_xtOS{`Mg8ul6yuZAy=Jh;G&pqADo{d$~j^gjPV5WD$^o3`4Sy%Y7v8x zH@bL4WqGf<3=M`H4l2X`g$t+bq3A{!8Yt+fsW>M6qaLTS!FPeU#Vi(%|FB z)u7UC&<;`3F*CIm)rKkr=BHj?*G53(Sf`%=_TB%$%cby+5e?$!91&T4CJuyFuo7o8 z5&#k>19{88{9bye^QQHfX~HdMw^AdDC${dK%Nzk2=x@Lld!*-U{m*iXVVGE6=dgQl z8;e5vzs5H`0HMI}uQFM!f`Z&M1iICejddJ1d;RS-D&r-$$~IfF|8?mU8_V7rV88>B z+8its>NHD9jT^s;#4!%$eXh2(5$~UBv2FKx0jQ1n#c?gMm!Xl_e>p!g8#!WqIZvpV zJl?{;4zmj2ZbM45%dvfs1waNJsMj0;pg6XxVLF0cy^>C0#-33iw&{t_iajDr$G!Rz zA8f*IV}fJ}kh?yOxo2P4Rhn%*?MNOq@BwrG{nmJXKHR!WAEDSNuG{1v1^d2sNq_P1 zMyvNaP{H4ZOb8T-qBY=%E-tFB6H@gUj? 
zw}*>Ai?xF!rFHQh+igKvfVR)~0Zm^6WJlMn^o{b&GCR>2+3};kaS#y$vy=p%0Yh3_mQWW<6Z~owl2rNK?NWXzW5jQC>6q6BdWHs_ z8cz>H$=|40A+9)?9YXTznQ>+yK;;2$NZ(;E?}gH#0XdgZ%e8u?*tH+JbgU){s9SBI zbl9k8e}bsgqr))_StsUEr?Xz?WX-4}vi@j^#r_$~EfJL&v9E_&Tx4%HzJb&(v)a;G zh&_9wdn$ zn2?!ooR}~^4RRDYcfu*76w`Wcn6Fl_*b82w;q=h}3Tw~L48@?Z-+AZ#M2#7oR!nMC z#xkOCs3X*7lTgJ%Eywhzx*qBCuhurj8w32>DtGERHji)Me z4UjY5M=KNqoXSg_)-5uct}??e?%HqE(mxYvTTeU4)YkWQ9j{sI$vs3F=WW-{#gi;( z#4FGdV;I*d$YkzJvUDvGZp#s)+{`Vro_R3IKRW599OOFQdei)!OWNbLQ!6aL^TTNi zjPW^BR)6xD9MSde59ob_=&P**Mtujf1N0I?&NMT04iLVww>Yx&Zo%xDkA3v~?XYzLw{Y2KWRM!~L>KsPg{s+cWxBtt=KP@BQl z?Cu>9-!!l3yF?gA+={QUdeGfz`VPf=Dyo+cz8<|MdR}S2HNs-sBq%8C{Lv7nxfBv9 ztI@_|dx!+gnBo9K-!%BuBNkLW#nsM#pGd@{j8*&UC^Ty*D>cUV?}8fH1hj)|lwWv9 z1~>af`3*0MIMRBOQg7Qgk8OB+eO;D=c$VVreEY>Z!3-Fh*Aa>5^&+%ZCt<<}Fm0VJ zyxK6^=tD)=VQuB9Ro$o}6 zlbm)CL4y!HH@E3V3_hv z`dnvtFp&#?RX#bbd z!la8Pk8*GP%9?j-YGWN@`&?RczL;k5WBPhvccI4C+2GSSp?qS);CZVx*}a_x zE`k&23Zxb2nnkek>KtCOqu+94vKn@8KDh|1{L)v`xIpVga#pCQ!98v%4!L!<6b#rDi6rrPOrPAStV1B`gyz1sSG|mO3wq4=sp1B{Laj^^+6b5xTcAjH z`XI3c^{RzxcaaA%X6oBr}KbbV*s^*%YXzYWpJtn~mC;FyHyK_5x$8-Cz zWx@6t{$lFE5}y#kbvLEs+>;zHW{GwCW1)8v&mClgwMw?r$sVCnmrL}!p*MXLmPrK+ z-cE?x^Vrd5CnmlzvqT z@%*itAjij|Z1-R2vuUoruV_94%#v}mVdm_9M z4k@}?oS**H!cvy3HsZPxiwO)7+}j5>_wz4ToIXcpL33@=}G~EW7gi50(R>(%qU_fby{jNV@Ilt$jCX zFZ;akQ_o!aNO))ruO_f)>dGjxD|t`va~J5cQWR_(sl~p3-dB)#=+Tv_Ik!Ni*y3}} zw-4=t1(G(#SZCYWy*Y)>nLmaiK>-^PR!6F zf*853On*6BWL`Un*0Xjkp@S^6-HyZ#bklgTG{bRAZ7k7Yh3k6#@e?;aS(9$3A>ppK zo^vbu-EcCD#srSrhZK1uZD|hCnG_#UdU^0}*dgPkx_dBRcq?PncL%l@mq2uTDBh*= z3QZg2YdnZb6cP!gbCVU@>O@C;-uY&@jqHurJ8zrp6Xc?hPW8XNMIci=gVR71%)-@ep6*iO2jIJdp*J>U5eWB#2-?qzHiRyUer+k-o$1fthq zVGUCuuCwvg4d`*@aK~n9Dc^i`-+O4@J@y4`}#6~atr+552 z_QXiTA2F(INjn&q0yEEK!>|?o;5<)`ra%YL zjP1PJL1|0Q8Y!#fru&K!pr`R%sSc@5rdzeHduFwtH(-0diek;qIHt}9b44hJjF3rqM5~&ho^7{Z!JCj&H#=A zb_X(|#By~Qg;Hw*;+UtjkfVkXTpwV_wKc?Ybq)vm?`i@pfZPXw4X6fGW_ZH@uqx_+ zo{|USuCo69`6^I_EBx{#7>9+Sw6U={fC422ba4O_!vUaog$J>PC42)!8+%)1A+uAl zF38|knh&NuY>!kH&4YTuDQ>Vlpjpqq?w~&irYtVbL=7O$X%E((#VE~ZJRv=a%xcE- z1)}L&`oW-WRx<2`Ne_l?5eVkK#QwbUk@od~j7uH-?Kb?_L|`u3)ob^^t&IjINOh&WFd+~-HYv$;arU#$$IZ*~9^I~Dx*Xa3{S0GSiO=KOK#yUJDQ zom6HNsS?F=ir;uNR~!wX{Zv7m&kE2B!wv^%qHK%|Mvrkdp8D1M9^HxmtkaAHSh)5tl(eB_#%zx z_4R?Vz8>dzICR$!eEAN($UGUcGAOKJm#AR$%uf>u%Reurhj)kU0Smd1hjr)LS>!|@ zknUw$FM@4Cmxw&dOB?_F7-g;tg3ke+6@FFlHD1W&yO!YDhNC}5L^bcld5b1QSEoYYQJRHl(V;_Z8Q~aYBIhbc&=2_lJebq?jF9I1HXg(tF>P2Znr<7Uh z`Sv9L!@lqjtY-CclW~qY20828>?yV$up5u0vAo1fomTJYa2p>P6C%!hnYl6^o>Rcw z13xSy=$na5k%GZV{p3s-$$8-*=Ad#pj-3W)Lz_naw&$Q>g_0#7)y_K2*DU4NG zX+eP2{IoPq-LP9unXsJsfG6fef?v=OvC-A}ZGK6f#?f@4;|?^)CJFwv353ny3Pm&d zbJ4yDzfLu7W5g70RMKikiz#E|w(chTDKx~(X*BM-QtyENS@*Zu&iTE=I7xPrwht$=j6ZUdxVu;9DY)+uk}W2?bo z!-di4U`P%QbTP)uOSe%$w5JP3Z5*|)LfQ+gBR8HCG(66Omv|YbQNKb=ax1$>W%O3w zeB(6XbWbC{vC$Dt3nD!$p&wY03$u~kSLf;cQt+OdV z@K-}TS%%r#_qTkz?%^a36xJE{{Mw<|5;Wl&d2G@Z$(1y45JL^%M}ffpbt$}$fOI^& zWH!14TVMTPfUM1p=lCdzbL4#qq*C*9g()~!l> z*1oz^qwXMv(v3|#gU|_V9@*Y7!ShmcS`Q@D@3cn}p7k}RAspryD`x>FMj3ZONUFBH z9!8ngmg|02a=v-IkfNmI6vP%D-R&~l012!`>m{y$NuIhnM{An5c#(g-uCCc?g}ByS zNXnpG9cX`0{5pt@R#zoXm+T2%YoR_`DegvA7Ms|I%pau&pXfiuJFD)@pO~O-XHpjO z6uI5e-zuC&NV`O`6`6%~^kYi%9>X=*^Ky&OR*@LbNs&DO_7&@Sa0}EeU!I_E=9v}!|w{J1{W-`BeBCD0-Q;V1- ztVeDJ-}a{xS5aZTf3QV+(d8pf@l>Wi>1f~*TJk|unK80Bb`E+BQuJXc^3+nUv6t)L zCGoa?r6;1-`r_8I-rp9$YNc6N!wTw_%^A2ye|4oVj7vndQYYZgb5-)EiXnrM(kk7wn z_@tH5mFg03B(*$zl=krIxagw^HnyX9Loy7Vw!StB<%7SapxWAEUC+*`Xlo{dVt_B+~b}eG>hX-8c9RP|a^x`7fy={`) z>a^BF^8WbPJ`8%M!UhJye)iI$?1 zlg{{;yJ?bBC8wPy&k`;a#g4_iN`}X&(BKjjM75gxv9WT)OOI<|Q5~U|K>5)fw*$~> 
zP1dUVoCLhW$~!2cmeq-q_t=dk&KKa=H=Ij&boUPiOBw{W)=~(dfW;=2k_q?ml zPXv72Cr~JYVR59CFkc^5$hvX_BhJBw|c-mR2x7A(W&(oN)k!%GM?&tAA>U8eF|nSy1GNM zA>M((GR#a4dp}u09mNK`@;!;C6Rt%d-rinRW@s>cT`Q3q8F6* z2S=_;g8G{k&I23ETBcnK_Hzu&-}iTEew4iLms{r;m?W3su_jA@_C6=6PXN!0@croR zI|^b#OU)nk4;XCh?DMT19Vt4^xI64=ho($i(MjkOesOG^VfA^!wjU&_7f(iv(NSz4OM-%RkeCJ)ucEWN(}xPs>qCnu`*E@(kC;p5Y$l*p&i@I+XF z+}&Ts;sGK%?>-skOx+sblEkOT(PMgy{kE#-@gSqLVo@_G1lr1rON1^scd8 zgAk*iJ1DILwm`6d3r`>u5 zL^j4t6uVkbh)P5coopGAmJtGhGKw=TD1sVB{N!3$%`wtzRP|cO$y@Xa8{~`bxmR!T z^{#&!<-ORbo+L${!ns)Yl|adD=Q|Jq5Pf7LKSSFr?5mGtzAgHh(#mPQ&S1z5lb(mr zL}i^S*nM}&10ALByV8oLzwr_<*N_sSuI7W|InGn*eK&H!+p0+yd>jCWal9^87ZszmUz57g`NI%OyJqX2YhAw(- z8C$d3aJW=!KjkwX!5^|a)jI!(AVzt&!rxE5R&JbpXVWhgB$SB>&67p3VV!1M*8DlY zqQ}Dp`|Q2F082@QV+if`GA@H|qPDtWjN{N1^pYSZvO#uOUVAfEtC+89mvTR z**~E5uc%E0pjYkxBM_Z0(yXI3aW;28en_#bY;Vi(X6W(Bnkei~C|9`;cz(L=t=JPk zqo19YrWEk=La+LZ|5uR$PYyQ4C?aAADvoaavB&|K*-sQ;NsR$Pn&Yxusem7ib<*{d zn`cmBV}&f+GM2dI=CX*lvCb)XE3Q*?t8{d4uko}_iGLD3;Mg-kMIqf+Hy$7sQT+Bt z#&WtfA4MBQE_fDJ?+r-Zpu5e=mt90)7Q6}q4pY)5u%9TH3T3Q5B`E(Dv)3r`?pwDM zt8f9I|4G%l6%pOFh7ec#hifVjvSD+=B$ICpdR5ln=c?;XgILKKZd>}jtm~ElY81Ph zY>Eu5m&oNQ@zQ806swb%y%Bjqd4V+VF($r(oiRgs91hs-XaXERk19BXR!#H;f59m)oA)Rquq%S?{wPR?9VOViP?vZ17)t%M3Xb4lYfa=1sf z`%4UydnR-(3L%v;5GN~L{WjcLm#W2kkzv9{sxo$ujb_>}l^Y&4s0?!nJuCf^of9gU zjsmqc(0r`YOg$EY$Zynh#XH?Eh6D13y-!^8^nKyT@kasq?^)D3~|t}lo}Jz3k&FT~Bb_pIqMJYfVAuvsX5$daiGHME~LGp3I9JOi;`q0g%bESD+JIy5dFiq z`H$-x&|C!H{(g(7KQHHn0VyG%gI{`w9nzy!Pu;`i;k?R>$M;IukldbmD4G>_ch3=h za|uX}JO%I>QY#dGY~;rTP_{}y?xf3C^a*1y=E2m`DLEYI!Wr3nM zt(f{!;`jdjN6LR*PBjg*tM)BM$<4Titzol#oCx8^2uOo(Av{>-!TZ@8DND2B9gmbe z>m-0EtI_^{uBHV4dDs8-AON@IFZQvTi9s1YPPN5bES}!_8L#-2W)An)Ys3AA!>#fN z=|c;VHuNJgI$^$}oz^YnHNJ9UQh|t#22H2w9YcDLFZNMkZmzhImdm%bTf>suzD47% zd-|1?>Rvbb>id`^B}@M@Wcfp@nWJjNEwsv(swz4UTDb#+jOh`8MIj*zqSRD|(-DZ9 z$r7*`-r;FJC(!!6(6f(!RrQqPAd^#`QI)%q;WWUJN+Jomi_%;a512;`#-8s-d zlDP?VgHPk#n~D%Sg!>St6v4f9Zwy1Of8?i&XMF?x05QREU4b}GAmh;cAY5E zwWZ0E_7YCP9_2Wti6KPd%UV}=j}>DMW$V7zGSm$-(qE7qh^aUBIsmqC<=?!QH_I_V zTz&rvAOr&d8D2I5GdSJkN7(fs?n)_~?MJD%D4I1S9U zRx~i%US^dkZfKn<$jv#qR1zk6l$D3-kR0;7boxcL-5k(|oY){zt9+0@u)MM=av|C2 z`LS%DXfW@2$^xHTL)b6pGx)ZbT=j~v>jaY`c4y_|kjwav4)oSs1flBr)gm}A4Okpg zXb5~`4o8ILnfT~pVm&g)>Tf{Qcd><@=B)HNQ8{g{I2V~ax@1;d0pYXbJ`22^aOeoJ ze353;){%y7m@L3B$DdR;CDbNQ6I^@y5l6M;d4u)N=eL~f0vvR}cTDFjKVJ4FH<6)O z>=_ZuSn9`Hit71go;}=ku2VfBK>DEwafD$5BfeuZPp*`4AHA0Lz5Eidn%qir$!>Nz zJyu{kTilYMOIf?rf92(PuAWf4(D-H#^{z%FRmzL4w%N4BC^pS&Iq3v}Qt8%w6j6GJ z{aE>{ds8@tRVubQE)LCJaVmjTZws+>v}dfn2yF3-jgg8>dBp%sgVN|K?qpNZ@vNr~ zt)GQT#+kyGA!B7Qzf;gzA1KasI`pND7o&t!8sIrWuO-bKS9;y4c>^9Fp1n_(Z%x_C^>#Zgt$jWYsf<@V zkr`ZP{?QG(XY|HKY}`(PvEv>U4dvMgis(FGWc+!;v&Kl9ypu26-FizAfj8SO=Iw^% z?(W>%W0olk<*-m8jOp#2Gkn2q5^(mKci)i(1y(?E1hd+SEm_;Kdwe-^#@fiGAy%cz#CN^`)EJd2^ ztTT;ltdUezemAo45djwgf~p`F((Asp$B7FByD13f5m%;o#OhzN#JoL8ntXIxFl;cc zxg?e>qdC6f5#+jHCunJfca-jGOvl^k8GYBD{-$Cu08q~+V9CF<&tyVVtmG9*HwGa zR;ig!QB_i+ZuY2BgI_BZWOYC+tG+Un=YAo6cYh+$z3agT>g4Nbhr;&gmtX1PuhkCZL% zX4gnLmb3o4*)*@IeCu^aP+BRk@>XeHv;S)Uw?Q8B0$y(QSH4q8W<720My@WNcDj?s zM@GYTcY&VH?JpsxroQzEXd87UhI>rZs--bhl|r8@#{OyIQhE=gpWDh3K2R2of@&f; z5=xD+xsyqK{f-uw2}5oDRDf|+N3M+?Jxjp-G1t1LWMk{k-N}hH5i{KeF&{5P*FTu) zjiJY@Q;bgF;uwPMrkAUP>&It)mM;Z!N_wk~8uz{C(eU##kt{vZU{%>?RC8|4c>6pY(a~86JI!0FLK(u-q0FPcxxqFsYerCsOuoPeo5sEdO!=zgz0t=Ui;Y z`{H5@Z*EC8c`NCm$}^HFu~W2u{x&tY)e1CtD$vt%eBpcKaZmFnqJ>Di&fdqcH&H2q zeg?4ao=IfzU92Nl&MRIQjYh{z_+?gaiwsJweYm5hVOK+DTHAz8GG2adz|HvybN#iZ zt<@VzHDEFsDv2aj(oTUc!-}FfrS`}LyWOY|x!bWlQY`0e;_^RB+6S6)kzZ~L`74)r z(sW}VRFU2CR95tQHuyHC-GnrwhRs}Ji;D5-r|)JDYXuwDn0zjy^KQH=1GZ)}yME!= 
zng#(?=MQwN%A-@Ifd^kHM}%isi=%OV36_G}{|b&UUfkl`t9ruJrm5u9LMaPEP(EglYZJ zlf@8h;b#UouLK(dT*GgmEaDm&3MhHJ=~jy+1Yil`J^)kQ0H*W7fBfw4pU?*@;yFuk zL^A^|6?vIDFPJT>3*>x~dpm~5BY8{zCMh_B{mYMqvYK}ifK4S&Cn{z;kH1NM3pHlR z@$zUj<#z>ugc(g!vfZxK!gr~ftVfLV;9V{O2TSvb!|0*Mnv~ZW@y%-aoCX;>cn+1X z8!&%iz6nchy1u%EdSymchO%3W5Px;074*H33@|5cL z4&0(m{6`Pm@wWXK+L%8fNmf|`_H~u>*eB!59f`x)NDJlLSMJ`y^Ey~hi8|`vYIbiC zL|FIjnd zqXNkr)*Z zN}>iS-&fk#E}cl+(Y!Gt7Xc$dBrl+bSF5KxjhP;wJRjLVh3x;xX{n9xx0ko7yO%(H zQuw^wIPjr(G>uc-`R8<@3pKng$O=DTrAe3$q};kb_+3Xb2VS2JeKtM%!1tY%N$o4^ zH}|97@{U6)M4*?)0e=x>`cM>!#&~uE5}asv&weUT`@4?+gJ3+Tkrzakf6!aXf3Q5( zf219iXcqY9;Qft|0+MR~qQ{I_W~9KFOMueyo5pbcN$NcTVCEmR(MfRU zft_kSs$-d?PGBvDmd#P9!$NRKY8nu*{c#K4O<1U&p0or<@c)1n{NGBpqyw%be~^?vr+jo)#D@c*0xm29@^VrKRgvow{b&|CrrYppMxiHd!yn#T{h`;$-ZqOP>ZXsoI& z_qk%qeav4xVUui8%BaYo$e^74Lr}frGFbb-1W=5IMp@m$9SMSuC5MTC&qsA84X8pW zf+Cj)YfT8u0HXF^i*G^`;5sPTq+a4s{q+-5%j2-W;z!6I6sr;q%6b*x7wMG~cKujK zl5ozoqj!NyQYv6;{t~7DykOh#qrx4*(^T1N!!>o4sM{aNA@xwTKHSgo^Pnp}lNh?+dG;XWiEkjEn{OY5__2+=s~rELKPo2hB1OEe zK>2HxU?t1xc4ZUf?y0ENszY)fdwg=m6#HPmt0GfRwX0D}K$elA^15MXRRvXNqe&9qr* zYu)b>>_qRFCrO6_ZKzhNsLQ4EneMtAZ;jQ@GsQ}qzmypSgHh&;s4#`&9<-~e7Sfl; z8-92D=}6dieI)3wm+)12j!o`F-?>3WDbc*y@|Cyl^sGY)TNrgYq%aCa10)&Zs(5h6 zm!UGMz<%-OF$1dBtG;^|QS#xJXTr)9EW=x8ri<1Mh6JFx8eH}0>Rj=JJmBuX=gCXh^087=S&UJ zn)+RC<@dU@=w43e$-VLwIIxI5qy`pLv1D0%)FQoS5I0gP*fZrvi|h%qKb3BN#Xk*6 zY$mR}R_x94!fxb!)q=$AB1=VAJJ~Y$h|D(BS6uUozu=0`lm(c+B{lT!srk#~FO%co zg3FeB_oPH8jWowJcQ8@BV@=YC3$5C*z@PCKwI5|{7-=ZTH}M$~blQ4ut5R({c?NVG z+euky5X*Abx#&i@l@C^!3NXlkT^hU>Ue$F(%dtFkJCRQFRcj)=Jydv;)O6<#}2Kqs{5-m z`@f$41*-B(Hm3=2wETQ?)q1`5wKm@Finp)b`qCFYCoo2Fv}6-6_mc7RW3q&Mb+i`` z;bN<{5JZV4?-d-aeU)%c4}M2Kr4dMZpol>VGSUFS&H1Al!@9v_ya=T2j4KgDUylwB z&!Q^lhXE;GgS#`a&kSP$qgB!=U`Sw-WTz#zr=ry_)!hPD(=b;adjqdED9(8Di<3OR zZ{ZX+R=BeEGE_(3%xcBc$6|ZcVy7mxR&#-j4;8(#bjP#jN$u!Sb$+_aVf=@YZ#_Jb zbq~sW2ElATuRjkBgU3boZsEqSM<&`iR6xA@=j$boy8OaDqtwT#r}%5WMM10VfE{NU zGD}bU8krOGl)gJpD}3b%f!zC+I3&jYa~|>I0;(cwtZ;krRI`aL*O1vRkU7{k_aUpV zu0qbgTvRFVI`{g-KOQ~SAQ)kjuCOY5cUya|i@F`$tN0Z9$ZeL$wjXsJScJ$#2Uf#t zvudr6dWY8OdPb&=k+>lS2)U=%N(8Gac@QOjOof0;6yT7{OmfAnb+uF|cSQHlxKXsOY1CkE;^7nvN=SPDt;;Au`V%e)$!Q3Ncpe{yzQf-_z~O%CEH=9 zL%JrJg*cwuS%%PZocpOtZuH4E+YRy_r@RA7-}=8_D4)t1{!pNjrwz5qVRr+#K%L^r zCXpnM;40W%MGwX-?c4N>(8uC0U82KZ`wIAz9JdOz1Yd36!}2ck`L)lS z!8+{qHL4IS!4K-{RI*K<{tA+iIAKYuN1gy41uWN}GL`hJvw4(DypjYtE8?oHRyG35 zQC=(5X$y3V2Yu{M+30duId`;BSQ{InPg$N1ROSw%Zx{MIRbzS2l8}F$j}Rw66`WhU z8b)aeD+@`M6M$KxKHh<@2-26JtKqZ7l!}|&rt3sk^?MWg$pbmYJuAkhAA~PY7Uzv0 zufLz6bdvfel(MOi<|_efgo%*I>bE{NU=!-$A4)9;zrHlgTt*1D!M?4`mrtZ@E|y28 z`Sd|^oJxK2$8t3IIeEqB;#Hw(qLVf_res}2cUFn9TJmWD>*HrCw@1}(^0hRog-0o* zP4-VuvfyDrK*bOtYUz1HJ2u-`o^6)yn;b%ItJ%v-;QCv;E1F65Bh_jYP!vbGS<&`- zkwcc91T6_L2RDY}vm#YZ--f6+GG5n?(JG@t6J1HqTI0wO!deBkcXKw2gf)>zERBkn zR9bLCXsNUw0Hub5r33@+Lp7fW&ML6AydTX)*1vL(lp?LiXnZ94blNP`#Q)Z4i~=}& z0g$tm<*mQd#NIw|dy1o|>AcqQ1at$*&`A>+*7|16)kUA|42e4EKeL0RF!k?|<|WJykn_<3W0c!$AAu zPQURo-=vd0`>$KT68(69+9w;7$8S7SAgiQ6bvpU?HlXQbxkYh+`79Izh(M%vixOs+ z5HW{sJuqzLKh=I@2SA*Zsz|AAN_1E6@2+H`d4@^Tm`^akbYl^W#T&W)xBl}b4V`~0 zGo1T}dieIgZ8Qx0fFVT{K<06owW9FL8yf$V{Juds;^cCED7Z5@rDyLw+l9I{s;iBPpZ^x(lDGs9w6frFsh9FcA(t=T2*utPP2Ad{2Ior?0tJq6+sxFBG9bLxF`iT4zESE9%uy*!|cjo)~q5o+-5-~g;) zBYkl6i_^D6TVQTJZPciUZou`FaBqxFog00 zlFjdSV_G}*Ht{hLUU}_~f$;YMPmo7*nAm&L$q-KIbM=P-aL;Zf(3j~(pQniC;cy>* zmvp1&@yA1>Wgdk2zgFsdWL8T&S)SA*;DN5#8^_P$?G!EFMM-OJy<5A_e0J=1oTLBf zyqeWBzX&99=oj{7l_t0^jX6+VV{pNEzi#HR1pf03MCyAL1q|M3Gt}?o^;wr|3;opv z3vg%-<0pS3^1xM}x!@T6thRKaM}A8~7J~iECRE7VvG^Oal9nds@i&!aTsIqm+%_3bDT)l3Wk7J8 
zb=?_UiILZ0apYt+&Hf1z1@8e>{&p)Z^=nM}N-d+a*Qsq6ZQ}Z4=;9J+iAJV>Zv0Z6 z=(nP`Js7}5H4jzN>q>;C?5T~uuss@7LBn*%leG;dEoRDM$7SsTqTQuAFH~9LK1O)O z^;ejv6hHD{h8tnQ41~qQTkms=Y`XTNq)|(bTRw~vw709AJ3A1Gk0E^1w-9$acOgVR z8o5Q6?+%X*6v$IsEXYfpmgrzcsE+p>D8>A48c^NIBZ;eEcgXYUWZYl?jz6*tvdaNE z)fGK9DCE5|qS@2HxGSrQg1LUzl{Ggx7=nn{fAu=I3? zKkoHs;RkNepYtj&v=8?Z7(9YfM5 zxX9Q`ryBBB50888=-v?T3D?T{xLw9oC-PfZxY^-S1aN4fjaPKd!2Uimm&1t#(!O%< zz&rXN>$rCa#YF9I4PTyJk@09_2GtRptTY|)n>X;N8{OBi1v#vz1t7rYezS4~hNM)N#CywgPIRR)@(1J|WiH_c>373ivMaHs;C(C|Rv^QtbD- zeR&TC$zPrF5<11%S$`s`{@$|=vk)_nWir$y(h3`N2s(W(y8ZoM*Zy}dTmK_~ z@vj!}f9sR;3M2YA$I@au@bL{ml>+mEoD%%!Fa5>-OP$60Z5*=xVd}rNkgl)90`$%3FQ+R((imNGS%=R+<=ZsGSQ zvX^vnt95tUd2`z9$$(sLuiP*We*dqd2e^GVbuRXgXvIcTx=e#dws7bC+;Z)uzo)Q zBQPouAm#fOSUUpQ)*U7U{^CP0`VxT4@&#_`ck-;seBo#+CEjJRf!}x`$Ck!nmJQQs zv+_?|Cfu$POG}3)`Ie16`(D?tpKT>UX3|euMmJ!|*nO9ACcijQDP1pPi?>3WPT_ie ztOT7ZxqT*4i6(fy!c&#Y<9T&yV=i|+#W+?~dI-i#Tzx5wJ>ZS-^!;b6E@+GDxXq$J(0TStpyZPghKTr za*?LrRWH(pR+J}@`Yg{IgH^rlGu9sm;d`y&>FQ@+CYwf-c?R}oF6ks|oxSvsUa zq(@;+dk{V~==x@Lg2IYf%=Q&S$f_JGFqu>rSZ7-pAV5o_M;zu*7i1X%h5wiMiI_eX zKczLG@IUq!@yIlPZrz9*QHXPSr#j11V?WA>`Ck3{V^|N*0w=nyljl`Clg4hlEAqCnh}q2(zEcT&fz!f4M%&f276A%D+)8w6>uzU- zM!6~H7h#X@-A^32OSG^Bnrbk8IeEmlNDCyz2)(0cb5l8wN^`5rc`r0%@7T^0tn5K; zh;RdDklbHQ7xbqY{`xTz*yGVOyRq4RZ&~71?(+H{;zTl*koI~<;#vOv+5P8TPtrJg z&lR+JWTwLQZEmXtZl;sR@o=uw5~Lkz;)KVOx>pBK1k7%G`l8Z<1J-7S`(~w*Xmy9^!en5 z*MfLX4cNU~^fN)aMx6!Kx2RS~gtK+ykAQ8hbz!q8rWEoyhG&FIt>IRcZp92GcSt-p z>O*>OhRdgUo0quhb@UJPo*?ReajRj?f*-woNDg!q(mw`X6{Cz*N|GX!jgWt>UI0N2 zb%Q#;l&O4g@$}loIJy zKsuD}Qd%UWTN*}6YREy6E&&0hyJ4goq(P*lq+w_fm>~xkcz(Bgf6u%1+54^U+sEFd#!Vw*IMWKB>o9U)vLm(Mh>m#&-nBda}`>O2UPA>Z;e;;+V>Y1n)U|W zE(#=)A%HMMNs*gC68a3iA{}8&^%2uQHuPtdZ65aPzf+{G01f%`6Gok$Cf}mXuzZ$< z5#4a5_Udp>K+k4e{DaW^#ka3~Ta$&9I)=Poy@kA<6B&BA90HTi2`ZPd5knibN!$82 zL#gP7!Ox%6G#mv?SDzW_JN54#?;85h<;8Uf=((En6H>Z>-3DXWBA#TEVt%G1+tKK9 zWPwn&_?{KRwrTr*_kKeY0Kv9A^qcw><>Zu}_v)Vx8&p8AML*yW)mq+uEV^r21)Lhc zy8?Uyji6g@Q>}lC+&?qQuP?HHmLvo)w%-4re?#`qlLRTo@#LSDAHO9DLyzBOutC3@ z1J9MI6F~c03xKfN0=>X3|4-TTFH-Wu0M97dke%h(G|u&h`XjO_2i-%H8D(d3 zVWCGEX?MWsqZV`Cu2^X+Dw!kL4W;sz`WbcP_AfTtD?*5Zq}Qyc+IQ324+Xa;yY1>e z8A9p}Vir7$Bj#6`Po-+-rqw7#dRSa69EyhqXkWG;(kBDwZWmlrKw+rc{``tFKsG%< zMiR(E@+qC}?(&5B%B)2fdWm|lQmIhF(_^#ZW;m9;AZ4zpD&y-DE-;6PRs950h-Rps z@!3A7n;z9vIZjb5P@oGwLtC1zUIH}PeM6ntmh}zG(4uoD1}2(cN={$@7gyPiw5GPn zH9={%^%Ht{NFh1iNX{mjRRZ7>kem+ zHpMZct|*^ZK-XR{s9(i5#nC$SB+6J$ydIKgzH@h+a-tvO{1gB9RO}18{D69*3RYp3 zdR++=v@a4^c6lI8s;L%(SY@O)(|+E95rV+6k+=-Ki%IaVd(*Uhb3EB+9J+u#*Gpbk zsm-FwTRQK}RNMP&0p9Z-u2i_G>&n6TO~q?;J7A|he)1uS;K2(ym23_E>4+mJq2chw z?S+yJ=6IQBia1{-o6btRU7phGrZpLCJF-%0nl(Sw8I(GqgTTr)QR#K$; z+2FM#FLd7Q3RhbX^Xm3k>*Lsg5+h#JaqAx-SAbx73RG3Gc+!_Z9iOe+Zo7VwHuzKf ze6PUY!3ZA`33OG!4BFGJr^?&WQ9i;~K5FP)Kceno9LXOA#H5inX`r_-5N$*8CGy1I}Bfd%uC85z3ij&K~A;p zA^~ur8r}Mwp}4}c%~!C0d#b1J&YV`HSy1C?nq*HRqR}0p8R3!=zw5qM4*k5{A|UNA z?!(4VF8*Jz0Rldl_ouiR`^gl^&20hbi75$kWWBnHn=+zxl3^g+??snvpotx zYqU1`DhMbc$Hw)x7{dB4NNvB#qNDvDUn*^l=3OYX344DlH4b`=TdkmXY1tcD} zxGE3gu{0TEW}(zgANGxF?c}xNr4{}_(*FUXov(9U|NdpkY*7JdwaR9`^HerT@KviY z$n|`>PrAB{k2hKn`$a5Ir@ZuYF0*Cu9#x!1!B|uYpcUNH34|OA)3OskMA@mfflanHA!& z48z6<6k#S#xuEC{yviWi>O0rfDU$x497^k?zQOsmdzrMdl*C`#5KRQdNK>IQ>UfY> zkbXslpy=IF>+3eFh8b9dD_{p45r1pZzwr!*d3qMmJ#lu1w^Ny*Vtl89F1b4WvR=P^ z4S-BvCaUex>0m#0lm}6)O?!&bM)mbqJIb3;Qh;Du<9-+A%*FJ2et$&5`lT5Dp#aYPk`n!H&s{Kgt-w$RRM6lxB1t`eaK0L?{Y<)r8n|k zp9lGhd4k&A_y7YY^Ahz6eqfwu`mcK6&&-m0tS3ubPtlqe8E$DqoBGDgW=*vL<;npH z2d^3Jt{G4g!kT>b%fo8{Ax-KL7dlC$%iQ{=K>V(fJA3p>VTS2>4MM>ugNQ6)h0T!Y z21=BsSZEI=r4JQ5PPbL5n-(Z-rAJhCc#}L?$Hgb>CA@`Yi-5J~hi-N#q}NB7y+-Xb 
zQT?bkm~JiCY?bZ~@B?_H0f})y37*p%U_!Y|kD1h~2(llP2uU<+Yp9RuWe&F~qXw6+ zm^!2w#@y2>`(MOwY}oL9vf&xHn;}E{Lf*eTYnCKO*?r0D$4>4Ji1i$Bz_z`~0g0Hf zZHj`j_^Z=zyekjtvat8#OH_beGy4&&E`dXA&+e1yI?4k77!tp_@0~9DLSxLZz#7 z+dD6jo0O@(#3t==w;*`!j>k!&@1YhEzE5u_3`KJ-Z$y-r2?SB;C%{fv zLSaHVMt|uV+xpHMh*lK1=I!5Q1LSlhMt?HGe`)hF_!SFmVVeF$Sm6fndqOrK-29Dd zFHyXHJX?762?#SBdm!i^W54t6u)(F-RXA_h2~n6=0A|FFsy~BL49Gk>BJbPK-LIxH zL_>iqG7|W%1EN=QFras0HsM;eKD>PosCtx^vxqYPEct*7{%;m0oNXu%Bt_x@?<8kzKe!V#DO z^#S@#ArQC#idrsE{Lgyw?_H$G$X+`TzDI#Rg@ib0@GE+fRwUi~rXNW@yt8f7#X|x? zdyfhe?>|w9P22CC->0RWuGD+5A`Xsp-pj0cg7j`m><({L*;kkgq1qv`lmYIrBdYy6 z7=})Z1e%!UkkBhUvvW}jl%B0T$Tf2=fEP%KV1U(N1VXcP9oh^?s?*ocdYZ;jtB@6B zwd@I#D_gJ;PF8A*c))Xjm932*uaX9mw8IP^t&JsB%+=4Y4@IMNNpG>vu9^(S5Psw1 z-v%rUep`td#JpBVmi_2(i0ED1F3uac_Z4Juvdc?AtG})hDfGYAma{r1P=s?qvF9ua z<_J9V8$+ngy*Qxj@k0sq%QoKeb$jU*h9hBo%RN`ZKS0qw)?|cfMjO7V3M~`HBx4zx zmZgwrKh@$v^3b-G)wjwNvz9a<9Q%bR+DrF}giGfhKnsl%AsznwN^p5Cp%*oW`hm4o3NXP`BJ_KE1I5`iPrts)9yRM~9yXKuR1bAIsV4Sd8 z`tXc+7|L;+!YOS@Bj`9Z#p0l%H)-eHI@uSw@0{HN9M3|X6YJM6jTOb8`HXqH!Y`YP zoPZCUS+|U>?RE>wsNE*j*p{CsXK|SFrRQC>M=KT^8eG*YUfxx7Nj2Id9%u}wG{(0S z*)_?UYB2>dAfnxY@ls#|6(XZ+dnSW(jGi1;Eqv}~OuGO0)iZIzji%XrtW=v<#jIo$ zRe#z4_kmB41wO&!KR=;ag(FPjuyeP^(5I-J#`w3sYvR3TA@mYBGR;?yX$9r?nB`;Lpt3L? zZ9EUKZVEx3hf1^Ft0KeT1#mpVy0)hxfy+x%g2#!)oK7W;e0lj)AYXykUBse9(dHG{ z8NTN3n*FMQ^@@YX#X-6X(taHXJ?AI$*&}1n3AI$;yS+4YPVZ7r&toF6z1p<^jofqf z*(~6QibJ(2-4D9Nq1W9Hn4}Po0Z-bMx-*_Bxb{pu+dgSr`_=(Jza5ut=@1OW_M&DT z<^2`~!BkGKvAX$4gIuNSF{=6U$$EFn;{($Hz#MCjGR|q$5s|=QVdBSjSjn z)Q=4FNP4v6Rn#qGFIj&+$3T6Q@JRDi{bzgusD2g;m`P6VPR%7X&`~#a?c9Il$Oom^ z9s)|jx6rRkvFE;FmJI2(W%Rsv!=ay=gj`RHTESCxCGPjF$}3ZPwbi~{RF4R9E>Y~< z;gUwgdi&}N1xyof;U3f1u4Iz6QKOQ3>vX-mJeOwTw&nC}muP1Ugf;_r$L&sF<)xTC z4~JY=7Ub~U?X4jC97SB+7a&o=5MiF&u%JRiVfE=_alV!VvNy%l-|Y_vOos>d%P0BQ z)kE5`{dAVoOTYQb_;Beue?+nfQ|ZmY9o$sGGrR*#QNKh< zU7(6aKRB!BcD)L!cEMzD2sj%u9Xc~`y)=bnf z)I5u?MaS$I^s-QT1xH==j)T>vMjDS!Vlv^Pu#*qv$0K|Ox4B9;6(n+=W5&D8&?1dP zy8xqBUX<5W{&nP4_Eq*8fYLQz9@YcL;^)JEfL>a(1&x+h+@G%9&JX`4n*TPJ?;R~x z2Zy?`ubz>?#FtM;#ne@?tu%9b(*Xs?kVvX$(N85XRl>@Hqwtu zR!3pN2Z<;Z;rVKv;v=)mqQ?yHZq;W;j7`dgk}G&yLJ|PDd8W zp6k7jS|6PZ`K!u%kFI>Bwf_=dxIgG`cAl|dDm=t$n8B#iisF<%>{IU?+3@Qsgs68* zk&m`WuUtmbrm~M&LWQK4lvITtQq^s;#w(Fo!YKkJ%vM`(x~u;Hj?wl$vK&c#&%ewL z5!yF1ZG<@>T>ba=&ZKP&jSv9v&f5N092KTyrpfgyQBIx_QTQhB^93MLCy;&fwVHxC3^p2TWvV)S=d693WTF6L8yZVIY z84ue&GvYQpNPFMKJoQMzDG|$p3^y>yNEczQ2U{ z|MnAEg#5oHye%)E#0|jRC_nc51AXp5bwTSd`>7jWkkd@`lT*4+AN32)4MY&r7#Xp> zbWv&P4)j-02*?Q(Eczc%fr!M zMBaY*iWgBa9Py9ggVZZzQY2i`tM8f=NqXTm!#Ym=n5;gXHX~qOoao^l63hwe6WDKy zt5$AgtO_`dkB~k?%1ujI20wW2ZcfmbHL#uX|8yl1XWCPK+85Td>MI8vLLo9Erwwxs ziJ~HyteCZrv%&Oiu?m%{&P3;kGPb+wwY|;Bv~+n^w%?i^|2ymg{8r#e1eni4+#mJ7 zQ=rnn1E|#EN)mu9f5A@wkGQn4na-#i3^#eEUn6-&M9ZkuHct5+B>wD&`Q}`{PwOi~ zkJ?U1vtA=4{Y4U|ON@4~Mb;YT=xQtI@vpZh_t!mVl1_@z`)^#~23oTe5@RH8-8y>x z<;}+3^3xAXZkV0Rg#5d+t)`8p2H$js9N018tGca%`d7~1p-~<>n@7UrY8Z1t37948bG<+DP8@g zj{zLE378=`cO*@dH$AOfc!Fqb_!#KH;bIgN^nKdsIU74C-}Z#7SNK0z6ipnef4H9O z6o&O-BN1HI{zTSaI@pu~#;FzyOPzYLB{}ncNx9n7b<;$vbYj~UZnurTmaRM+RbD0; zB{6XXdDp(Z=B;XohoThE&Ofg1x#C-GF4ngHUbH;NWh-S;Nqive{g>(`_P3?>bl0KE zu!K2eFo``lc0)l7JCSz|@rvf`PG+&y)bo}gF;m0XHHcC3v3pCW{k{mpOaM#)Zg~I!`CG{lAE8zws_n?u8ut$u7RS37)wVyg2WY|3&a_xczYJT=%y*-w zZbs4tU&BbLd1tNdnfMB4E!Da9maX2mDRy8>xTZKTe`}A{| z1kW8|V^g5!lm5#FW6yUCf91H4_In+elc|gJLnsrWnSA|vI5X>>xW>h*?6Vb&vf8?P zFQ$6KucTNauZa)%-E5`7VV-Dk3X*n_DO2Ygtl}72{k}P@H`6W}-KXKw$lxVa_}X;b zI5UE~f`>;y_YLbic_>NSB&$QR_6yZ+_<%VfQF11Z- z^NWp9?Hc+M5Z}SWBcPANmA*aEN0N-=?OlhSK?=SqERLC*oT%;)*?`aM@Dq4*`dRVt!!Tx;y(csZL8eE7G_qvr42Gxy9-IEDz_2CUBP 
zxY>JiNdkko^?#Hqn|OcQXa)%R)}F%PACCJucGfbNaTBHZ{oPjdK6*NTEwBH&K{fu# z^EB{(l^1J^e)=~2Pp{vCiQfGc@g?FJ)rtYzoU773pR{o$sc~v=Ez&1;=fgi7hgUj= ze92y!lGHWSUC&%TQKv{PHdacVs91AscmcQq3SbcZww3n1r~xnMuTaX97}nzHnEkKD@9+DE96$qXLcmi`(s%QFdjxTvwy@wY7F?nUu^ zy7z|E>W&Zf>|gNEKh6TWI_Z6pdprfeiN{+gG&MtN`cCgNpXoeJt>sj;Wa?;^~* z!q{?MY?{p>5+XshI6!BC!tVmz4vuVEg#l-_$(jvFe{`HQn(5js@jqi;qK`aTW>s{G z@Wt2te1F~0)CFTm+r|aW2uktv+$g{dpcHC0%OATe@@4=hGE3iS0<6UYCW$W_4p4f9 zn#yC53xp!Y<(YBo^lXzYirnc{i(61@S+`rdmrtD=U$3~}P@U|Vu>Q8&3jQw7FtPa; z@W})=<+WzICa!2dJ_krKdwG!e@JBcJm8g{&O0i{VjRAG}vj`U4moS};-(f(+t~w*G zDtMm)e8jF^o~FSi+YC?d2(gqc=gx@q1TWgm-^mK(V^L!{g0*F(GXQ(0ITR`JoNR$tspx~;#+w=O~ zxX!Wq@rUv??Z~}H=-mt>&xen;P22w-$#DBoOr>0TM6ssf%Z`bBWI!(C&wFp$5euvv z@VwewCH!QrFaZMbL+P|#K9>aZ{+Kqh)}qnkSOkv~7?(5XtSk(XHGimY(v9a2DXBj~ zwVlF6fdvEjrq%!8n=&1aQ#ckxY6nNZ2yWyLXB1{Q=(X`q+|!4y&>xYL#7N4 zOBL*Gp^l)Iu;MckKA3E=W@@C7VEdxq!E$D>R~vDArs-HsbGYd>7x9C6dVM|M&T<6f zi@C-_c87f2{{^O@k0@z0pI*haO3d}+J}QU|abP&dybzHpIU4xm`+wYLgT*BKm=U9^LFU zo3mQz+}SD=@$9`(Gl$5>qN2j=qpQTwu+6rq`}O0(U)ITWj&kV@*jxuzszl)pKmnO~ z+s_PfaH)DlaryZ2@xZl6&djxI<=mpr z-8RZA@ds$$Ib6tlnBl-XI?a;7-KJIRKmkdHHkJPa6w1^XD>}!FXxiMGA^UNx$3w%` zH5L>sG9&2)EnbR&8#}S&lp5x*#p_iWa@Q##=nT8e>aPdisNFMUG*7(rHx8N)B}6_&hn?wADQNt zcuF%zjLBXfNTL3ZqI{=e`pWQ2Zq`fQOH80`3<}iOfdC4bhN99 zWsirCxat%=@0v(o!D}ex<}KZZww#amN)w$xY4y7bcMZ8pN$+}>uA6igdfq;I2kEHj zQmsZlOUU5~b$J#I?0mA$?e-UzKns ztZaAKFb)6hBdc$JIkNieF%<0kZ8Y?%SbD7)l6pPgejIZfbpD@;;~dxT>xfDW3W|ET zQtY~av_6p!Ig~uz2;|a7K}{$2@><5-=Te&?Zf#00yop_RzTh%MYt>e^wl$J_>K+;! z3%SzTnGB0cvtc1{1karN4#@m$5%kTtC~vL{^{zRvug$moZ>!Jk*S%Hr2N>B0c$-5- zn6NM>Q0g&K{5=&888pa)V8+zP$f5F@IES;*JXb|Gc8y4mo2(nwJhX{)GPJirheW6+ zeg?pZjL{MrI~nLDEb$VV9LxM2%~qf4nDW8-tJ*tUeY_OO52k^ix8P#I|vxZb(z^^6!*dmm7R|(oGkl zyt)vDF>Qw~!t85bAqILyF`S_YYO`yWqI4odims<({(W<}3v1`<8+*IyABLrGUtGeM zx?^Wao}vJ6AL*{-ixwY~*5iSfpFR&!Mlrl+$9*MD4Z9SC#P*o7UFDW}@KP2}}&tI6!cqcaq> zKtAZ%pLulZ9rqwcGwmgt9<%VXhyJ*N(9?XxV)#mX1Q2-y#zXQg9erKxDQDjqKc)F9 zllwY7JB-$96lYlA|1>f|W&Uh|(rTTk*%KsJQ`q5r;udUgq5)`%5^EZOsTFd11TMFD zK6ntTR485@s`*V`s_7f-Buf*z56Rt^EZ+>0&isX&Zz4K3AnaKs-{aHjMFb{5VHTh)Crie+}6+=C3tcx4n zZU*g$M3m8W5IGb@%mAcMGbTTaCT9{9=koeyl%?@bO45gmH-m5GSfZcK_0h;;t+~E7 zt)%G8oau#YzaijNj8~@QC>hTG82%Y_i%(hdCdjwiOuziAuA?huhBY^>b?=$xM9C(z zZ$vF%Y1kwXdYXyyC4KC`lI&N4W6tye!%Ra1>$w_1jk>EVV-Kl$wfLQqEDz=kP(l7s zO;RWFfkI#f@}9Xy0e`dCQ_HmJ5#W|^p9~l0+d=W%8*iIIX0OHdIM3q0++}V^K+wpZ zU`;S!Ja~li0-V?UWUu2~SMVCksLaahC}>i!rM~6|9&4~+&H-qE6H$MK%>~FVFiKx@ zl+A)?$2f<1~&7DT0)PTnCfSXklEuvXQ-;D=K4CTHl(69(E zGC}xd|Ycuo)wRd`D8{ee zGCEnU)lEruj)gG>+`}_ul>AIS4{?G{N(GOEL#D5Z)n~!o7t$Ba&J^~#7w;`AQ@)=+ z#c8}+r3PKdEE?ivTiD9J8fcMkaAr?)nwzFt70H304v)cOb|5SE-QntWYPd(XY=#fmBNG}d zfcam~q@BBx>{mo6P6LV>@ePf2)xrBt>q_@ylf2AJa_2K{D#gU-43gj|f=V&p$!^I4 zR!h=UewNS`=OA)^0X4X9yVQBYf#M6>%RUJN-a{JYRnCaHs|=61xpvHf6$1O}Qk2k+ z3F8`lzT@!y7uN$%*(6U>KxbI3)$u#3l$+5N78n`#J~su8CJxI;o!GQlBAn6b4BbSR z*!s||2JrS5P`SOYEtDeDQo`%eTE*6Pu{94Wtf%eIhJ1pd;Sk!SYL`2$?|c9`}RRMb+hoyp3FIu z$>AW=Cnb2eSWA2}heA-2&23S^%{B~Fx)co2bO~pTEzcgMQEZrInqZ1}ETB8KRX;6M zCHyha?3Q!GJM3pJk6z2rcgRD6E9#LWy5IDAbg8A{EJ-UrW{n=0y80nMEl>|Z-u7Vi zxHdh)Dk)Ps!ljA2Q14MkDSI4XvIJ8k7eI2xwNv2YZ|6Lmr7~VT#&fhNLH2)R&fPdq z1|j1y_DHq)jE%aOH3m6yhgEvZK2zKyzc0HZgN6KVQ}7qx6wXcwVYpc&1F-;&87)fJ zkwRfiP^o7)9qbQw^>%Z!`vW9!cThs_4l}k@*AHT~PsB-c;16?2lKv+!j~Wuq*o77=qwo3)py5UEH>0p z+P9b|rMkzpU^ez78Kc0~sAz>jqMZw$317$HJrl?`pKNm$C!&@P@XLvm;1%oUI#zO@ zxo3?fPE(VZ={UHoSorRzE{Db!}RE&RirCH@WPtwtUO;{GDh2%xdOa(jKHY%HVq>G zoCe;o48CTBQ#6<*&RozDtk)c>KK^p{@ohB^DfNQ~x7QhC1iY=x%jB3Gbzic>1=UH> zS_shiNjoi|KL8dZn_MVQU$y_FZY!1H%%Td(zhS>=^kGU*ll3A=`@v_&W+j{^O;g1Z 
zNP92!S)>OqHd^e>7v3ND@9r38{^wBqqFEw7!p1>Mz+JnQ_Ycde-yJKRHx(Se9(U4#kbfxWE2U`wl0xCxdS~p-0z@sD;S> zW^b+A%js!7Vc(&GbQ^|2W#%UwhXVqylh_gyJO?P7-gbd}g2~pJ7{dMh^H9BW9c_|z z)or2L^(kOFqo)9|MXl=JX5HFTQ}yh{tMWD>60rt7toxentjb8QZdllz5jX?J<9z!? zdV~C!UK3^33`e#579rMmS}bN+T^65bihcyW3VaB{oM(e!LL*aZE`-tm>7=pkrZxHp zRfy|KYfh@EMdcZ)uUb4qyUcujTs+!$WAfvS*d%9tBYpMJbH%!(Un(}ubKUNg@cG8e z&=&-<{2i9_e_}rW@@L>Cg=pp|GVQy{A|S%N6p#nmiXla000%Gzs~p;!fnWDcVByC3 zlZ%&PDf7x6PzQ96L@XZ1>@7XQpuzMyvPH51`D2EEG4_Ex|LmIo6Kaxb{%f&M00X%1 zKwDV=t9mJ#rwtXpx@+f-J+?2{e}W5tH$+zB!+kPT7QJCk(P^~LjFFIrZ?1!9;=Gzs zo(GF46Yyv3Ww+#)&0cJQpRc}X5==qfintZA>TYpMM?Y@kj_@5K=bRQZWmGYcXn9(f ztwwVWk4H!i9TJRzCLgEr&GD~nFC5EsOhu`my!Tt;Tc+ppk5G+-SVd4-YkLpqPky+u zcB7k3QVEnHjz5Yu@t1J|QoXkj=k9iLDMmY?Hc>>`EKhWfLtJFK*wn=o{BUDq?Aq=@ z5~Sju68=Ewcg9q6Caylla#mozlPCwtzg4z)KT|XHaimg=%c2eJ=GUayyTjTY2GSagFEX(n6>~W262%=00U$5d87UmBJ zhu^yvVK>@QBgF$q)=pF6Q9ohNS@POqlkNJH+GD;M#`(uDyu4{WUQL}rNjggB#&NiTd z4*u|=L#KoU749S(X{n~#6(rEI@w|@VA9dT6p+45u7$W$}?&{1qNLoTghih%kbnlJG zy_k^vUs7dm@s7}}Lb_Xt>A|IBo+DR=wzEL0hSgq-psia>M2UU|h^58%vou-l2#w#9 zX|$lnN^>-k!l&^v=bCd^VpeOj(MKmnZ{}JT`UT9#Zm{SAdkqgE zC4(zxulMdXOjS6hT%!-S-)|1v$k1ltgMhAR3PJm0 zM}YmT&K>vYmr%7{xw818%U|A-P{sct&m?t9$vQ^gxaa|$Dr21n$akF#|BU>!^1NQF zj6Kn0xIRhm7G$fdQ^1h62&zU=|Pe497RHKV+HB_zwn^MpIzh5X>5FC2fKI86a z1zNVW+CO80C|+=ht4bUzhAp9VfU_#kPSnOcH^J0*nIU`}AjR=uZH$nMSvrgwU%<9f zP9%?C-jP1i-wn*<{oy7e=`tK-wTVedit&R~FU)WkN_*n?z)$Bz?l$=&RjL#M!jnyk z5gLNAvgxp=8c+3Ta_^{=bjLn?-7u1e)&?loWK#s6ou;>sZ3MTVas^X^TD|Vk_6F(0 zx+xAAZ8W2bl=DXyBC31zWAcw5^tABF<5!hPtd%)@#dw8bK}w$oQl|^~Yh1`_Vd@Ha zfT~22$krl9f^Ej!9Z9sKjJK!ir|Kz^g@s~|WVry4c0mni%+Q2qt&Q4+4$2}LN+@w2A*?NSLnYR95KKn5<Ne#Jg#h3HaJUrr=a2ZnR2t-%L2YoK|QKc;I6Sp zztX~a7n5xFjV$Pr8kF)k*z(gRkn=O_nt$h&pE4r(tvdt3LS1mVgPxD$M!7$Jol^as z?I01J4I89hPPe8V-ddiT*%TQBzJ;GGY`|8DgAf^nd&s}Krh$${DYWf^r$dolH%|FPR`W26469*z)2IVuCC{(CXFOxant?EcoaRJ7w0T^Cnqyp!B~ zJz`Iz%9=1Jo8nDq4hHv$v+GJ)%!D z1rKW|>cMW@-RjIq6_|5bb}r$75sX2i?2US~!78`6gpM`h^7H9CW3ZmkkiUnUSIeyZ z)C+V?{1Z7CSAr6iMYw_R%ecs*rw?(8G)+~rQOSL?>QRm@G^=x9@huNk?}Ez7n*RLy z*Ln4G8%Fa+s>~Y{9CNEL_)CLbHI*dbzB3Kta`uGlq4HE>C?2JGNqZ&)qbL~g2(&2Ya2 zj6W8u9&^3wT$g3rF(9i7S}1I&uZ83;k=$Rb zaf^RZ7^=beWBmV?D6U0I=d(0cY+rCCYkn`k>eisz=#StPd~@(hjAN7eLRTIbezYEm zXcG(=YGZFiSXaOVGz_J*s7MDbU1{9f*w(Rb`W-&U5ye{5?c1Ztv1e6ANO$8NhqXwS zL#n>mpU~iH_ff_(zP8aHnZFytN_V-o&l-Z*+#^PYcFr183njxTz^XcUH;f+>41^J; zQd@aa(qg1x^l}cNpTiICqUQ7xOp&I!JcUrU+g9&4Z}Z6&QR-+W1Z#jsR6rLbE7IuH z`d*@fL_h~5r6pmx9i4b^6Z+kz>kp7(zFU)ebc=d6%1a-Sk?gx>4@?Fmo__ahm>W8z zafgbtdvr*5D)>$*M3y<2(og=V-J^G|Ua}uobj3>5IaMXWR+#-<+(FfWhfU$qjgQgl`hj>_3su31S(z) z0}bX<8-N%Mm?#!FS`+|AKcfMvo3esOJ$i{`9a&>WoM|``_9E!VfX68~fbm{*n_2Su zF>89`Sr*}5yL~L>_xYL&;ZU@A{z7k`^-x_h_0!Nv{Uj4b88rpMg^hoOoQxiTmN=t_I|!#)q#mZAJkMX`O+)3b<5wL|6|Sg)n! 
z=o&faoZ!h)7@B6Kom?${9`a^nPs?sv+1Ji*fa!3bNmTEg%FA-BN>bioms%kewyo#G zQh8%Rf9c8S5Q!~e%tHkQ+|`DAF}W!qfLDr%R6*WG^96T%H8~Z+X&x{gC~}dQd+wAK z9t!RzF<`ZkY$s1$BxSDnitv;(E9b;%geWuJPYjim1w{muBfu-Y%`4IT9Zhxbr^TdV zR1M{(^+dy#V^7|wt$~yEQ;H?aMBZYF@%hGlYkyel2wKvDPorp@EL}v$#FyOQZmKlOC zSYDrkFR8FT4V*R?e%~G4^H$01rUoJHodTF3N2@W7tGuC}vyKa}poB|m?1@F4u%|Qp zEf&P_EA{NmtM+uhTols8emKaHC&{DaCc>&g`-nOAZeM!&hZ%J$30wNw@ll+w(5fYLWCk!nSedHjZesn4w+yd;8n?|OF5Sx)7t_~Dhro(Mk z@y-qJ*}T3+bz8A5o8!B15D}w@I!(DJj-xHR4PPyq+VOjzYVkLnYgrqK=wdX=i&*O z*}b>Nns}#k->GoN)oAiFie@GBbsG)$4un(fdYJlk)X3#_!J@@N^%c17QNwYcDY)&I zVgLXYZMs+4CnU?#Bq$pASNr$hoOsGA9-&6UI#wfj#&3& zU*^4qVtA6;)kao%Ju<$}>}@KLr0`%3!yu@d0tbW<1o$O^{ED-U(Hh8+9w@1w7Rt*p z-{kxEYCGN7;(>>nWR1FU*~Em8+lS(Pge1X=o%QtM2%^6CO{ie(Ok0X>C46Zlk2m$w zo~KHik#Gth-TS@1SLY|N;uh8w214M!4(-YyuO=7syE(%|cu4o{k4 z@^EKh`}Q)-^mPkC9UNKKdAaqKnbk>=SJMPE?5m?`^n4 zAn+G8m4~+jQ){>VoGaAf>LgoG$BBwJw;o*TeD!_nKpddIu0yGTbE6C*z6@AZgTM@P z*LTO-I-u({rna+A56&5ALT0S2H!;&}0wq>|XSfrW)h~I`_;Is}h$(|`=Bv6RfM%$h z+@{Axc`YA=3h7aM1p$Hm$E9zv#0=$z#N zdz35Sn#Gy3+X1`tt&QIAF~T|!(GKKPQv1^kn))yL9Yx#mi}Usf4}HcI@qRLRl;5Y) z0!!l?!=$%C)re*=WCY$_hq@Gpq=53=`FO$$xbo;f;0pnYh_-3d)p`$>V{yJiXu3wX z$oYpL9qSDZO%3Mb3ns>*N#5)CFq)Jp>qu{E25Xeuh^hNq^7=E0N$Ei&p?EEdNS6K0 zo{UFh$LahX6mEtF`inc{r$_1~@*gUA)*}@4ZebhJxO}#5nQ{5tx%#yX?Y0ah4j7h_ z9|4B5yAN zwf9Dwoh|cVK3ma(l-c!)%+#X@7|HGC7Kif$L9dd1KBaPN@p68ReO`r;Li*0hB))^w zDZ=bSNlln`=$8xb-^4a-cP{b^!t+gch1be!-`+J=kr^QF&j0>;>9yHI2+`+uiW3HG z`4T%Rf!Gcqu?fb*E~QBISS=F{ao7@<0(|}Ki^8?`&1k7L8OCyd9ak)24Q24FH%bLs zF=c#I`sQyWabaX*Ed*8_ZfFp(mVQedQUTU4yP8dnZnmY4sSLPSCAPb z(iYLo%yG#`By@R!v1<{#UDS#31|}%Io{P40-zo#v2uBrv`oj~aB<1jTsYOnPx`e7+ z@qA7%ot*RAm{kkr(BjCa7XzAweSq6{ho9r&6GAi5RT{UnX+|l6ZRh3GSp4~&S--}EY}<_7t#=Azo+w!zL4ig_r8E@A3}h`ta%VTt|i zNMB8Yyc}Ke8HwEx(K~aT>~W#P$O6;0Vw7S&M2PpcEIbz1OsYXbCEVAwbe`_a=WSF2 z5x=8JKF*iQn>ufaxHWI=(P77~F#O1*LPhp0fsdf|tCg_>ho|Re76ugJXLNICfa<*^ z-H|~CDS`-nzm0YJg!}Fvpuy4F^P8*oHoWmmny=J)y6-xl-(3xf=l*cyr9?^LA|pxu zYx&#)2IM|sfRRn;5BtVHKoh?DFzX8b9(rPcl1MN@gjAolR|5yboF%HoFAvYouWi!3 zU#xGs%*587r+!*;bCPoWj0rXZemG!i2IzSvzDi*e;D4~QVDj>ooi1zi#$9RRIFkOd za|Wh|tXR`gQ&ikCoU~c#8;JYCgpp2sdn1x z?Dh3(NavNaumDE;S?-X~vYG9;?UeV1!QIp$nJl^XKKB$4vr{Pd%Wr9a$XnfH^0R-Z zAPXZ5M_$=PS4h~dgoZujDPBpPuxMPv7MlYpZMnE^+1=m`n$taiV`*9L#i0t$_UIcS z+>5lpO2&Fd(tq!xgo$5e*VWa(d}q1Gz6_xK=GCIwLvGA#?B^74&Pfm04OGl@k8l7# z%C~H8TezvrwICPl?{Jpngd^rOUby&f@&s|qTp&=|~T&YybNbmC5mPvF+ z0MEo?@X4oYBmvr|pb(2KS=d$f)G9+qgIbpGT%xrvy}6$*8U+3&DPpV|LeXg~0P?nI zi~9qVKO5{_BSu~iKbYn1yO8Y7b#(I8B|u|Iq|%C8hRa(fWZDMG;};z^u}v&&=7I2l^IWXNKedcA{`7F9HU=bSAPE5jI+I`J$R7`mEh$DAEDg3c=mBQvay`2%-I9a z*Sm0%@v9!is~)D(w?W(l4d#4{FB{BXxSjxP4MHa~#*YZjJJZ*9;TaB;1lsY_27n>g z5AtbEb;fNNT(!|Eaf~vKWUkVD&5`5tA%Xv{ss9Qz^pu+6t%o{}LlsZCb?MTsms@FH zB|W#p^~TAv_52KOJvN`)k)KiAs>M}?g?XkYCY*i|(rRpvD7W#&@} zPd653~5u9ZzJ` z6dNdvmK$$l+ydy)8L!e?Rg;W|E&V6r_o%X!vDas~?$&MPpGTA`1uOX#;pF@hq(Hg> zuo6SN6rKI*0F^m?*$3Tu_;Iu@w8)vJ|M{x5Pn&~QGRhk&mycghniuQbTE@isJ`9qa) zGhnqj;1AG#lx3S8X6FCL-h0P2)o1JCQB)KZ1f(ND0a0nvQA$LGfG7w^FOd!+AR@0{~{?>XnroX_vhAAHzpJ3D*t zwZ6}K)>_Z{KY&7lSvEd6{huC*A0LD;?31R}F~eTiC??tMN^WzDe7Dj#14x=mDZ8E? 
zMJhdf^2Cp^&*j&2J3~wTSw!19A)MZlilRZwBDcCc>1-5BhY>22@sa~5N0zJ)`@X1r;=h-ghe1av;AD|!Z8r;y1A z3%6F{nK`*s)bcu}U^{Qpe4>=$p3%Wrf5CG@5MKwAJ8dUq6`O<@+A%Kf1qI_GJe2GH zkQOQOk|aFpGBE!1WivT)wiM021ZFFP9QKFc52g6wdO&Wk;cg!^A!{g`nu=(-aV(4A zBv%BlUD-W_cJbXBV~ay1G})LhDbE(4isp99$?=bB+HisAzDqEbrzy zhQJgNNjcS+$7bi>p^fp7@JEm+)dTem!!s|M1&)iD-a?(nPvXbjG<+DmXFEud-^z^y8Q%hmJsyJ=&R=iqmGp^{WN+wcQX($BPCNe2YNDBO zG(N!ivg-krnT;Jm-*q8hg})A@JLo@*0@P}b0A$-YbS7fa_pMEbl{+Q`j*Kv+t5R&% zD_ngad5%(R-6;|~90932-Zecf6Yu%G6VgDe4B=OXsn!%tppXMvIp`su&NrgdE0Uh? z7Wx6*5aoLeu)~p&s>t&K0E#g7h&2Aqj zbb}N!-XBY=cmZUwS9T65#ZKOi6N!85De+eS6(kUkVz@g}f2N$+`b^0?a+*Y>h0abtf{KKQqA$*e0F4G~7M} ztWM3J_nZ5EmDtc#Ilr>W`}wx00k?M&$n$e}G2$tNHTL!XF~ImBS%*K{-L7wrlwW z(*d&6O;`ObSCu>d*cG~eV+X>IIZu7|XdYjpaqNXEv+3Kn4qJe5)c->~1axbfru)x- zh7b~AF2TZm>i{)ctBD&&D)y5#=NS%UX%6N0t?SD&N3`(z9V^TcQh9H&W%2Lgn12u7 z{N?#gbQBTRLr8Avb-kRrD1TGu4$~3)bLB72fPn6QYjB9@fQPEsyPo%)m*7bzT^>gK z@C5dkz~uGs_MCk_QzH?sFjnK3Q68sz!alJeQh{H@#+eP2Z#W zSzKuzLz}`*LBLKZ9sJEDm~BY)WNuq^SrQsiFwu<`Z1OM5qoj^JN%nR){hjA)Pgl%0 zt5uNTGuVN}N#R<;e9XgRS2ci2rA0IlnhI)UTUdhL=~&QZPQ@E-Ts@rP=ql;~>SnlT z?D=*TeuU_z({I}FS}fMxNIOjKVBOGg($1w)t|@m-ibjo@V~t2 z3|(%qfT!Z2_^NGI0so0;ppB<=@xy=uzqyuhpV{%2{;#sWU(TO+slT&!>wOUuyD~iz zeQ6NB4}<4QCvt0G=pA)tH6X&2gqo6J(Z6&F%cM>Z!r8|u+uARR+8^rY+j!kYm z$Z+6VKiVgC!(^=Ch?J<1Lj>M%q)|14=A(3i1YhfP!ELIOzLwhpYnecDXY>DF**k|;XRsvK|?UaG;(3`5IJXR z^$aT$8Pa^|UbSfCVF{MJ5Bfd5(_J2P-!j#`0d;8eRYLbwk4z8x`nFZMg z7jylqL?KliPlM^zjY=dmD`u%910GfAiZr#|98{7NXB17+;Z;&pQtHRDiw?NcPm(_b zIeZ+L?nl3#s#n}S%_yBljH&8&ivu6psvz8K4fF8Ptx~Yeo6f8M#3cG{=<0L!#S^TO z@Z#JkuKonT{n$z>rG(Zap~~-Q{+_r8p3pSYKOJb&DZU2fh6o_z7 zeA(7G?>`_7yG<4DujS_s6(T}Iq)}FaqSxvwht^mqbyu{k7*om%`=8C4Jy=hlAa@G z7c8hxiKQjnUp+roox$UeC{dq#HR^D62GnF8Wv&zwjQTbHuu&jhr9%-R5R&|NV5Emy zzG?9YFTgYbu3M*oeAS!+8wHviIEO?B%L4+bfcmsv?Wc{+%6qM2)UI!yRTW^xseNUD z>`Md0?H-WkkDzy*;Ny+o_9!erP%6&RvUvm?(e`}C1dfb*#qrdw#bnczJRX=`mi}_g z^s$80D!Qv{&`G?&$?c*^pzQ$`hGr+xYjK$`#Jft*i@h8l|3I!FP?{iY@~d2pSBLLy zwQY4>m%M^Mc2z&YptSC|@b0ayqkbkKxi<{GHsk8$+k}Ni!mp9(C0Zkba)ibv9`4(O zt|b|(4Th8Mk3}9TD{sl~fP&a%khL@@~?7o#cVWK=>P*z_S*r8JX z*gf)2Q4d4&rR=rN$WGqkAr2FtS2V?YpieTHra>`i1urKS1j;!vv(mD zN30fDxN7r%|9Z+bW#-=gFUFjgRZXuGrQQn|-7L)5kXlje(*CphEO%r9K-~rH^Vp zB}x#i5D7aw_2hJQ$gU~Z4%fH^^GAM^lv$>69INsjxwy1#zJrBDj_=IO$zgOgBcrD$ zbW^4>4$rWe?w*)oX?d)$@Jn1yEP&AUIAFjc2-E%n>R-K5J9d)UlIXX} zJxvpwHmK}>=w>Z;Io`J@r2s80?EOe6#0vV{HtgcT#-fiPW9n!Q^mZ77C)s&*A)aFgno z!HNO%r3<4S0Y`hkclX@4_g|gWVb|d~w9$TI_ZbPcU+^paFjf7M3$q)*anfH3Gmvjv zw19+*P99!y9RojsYkJ}HJo(`o!*_n@k>Mq$_NitFF#nFW=A%xOYReL$+yZMdC_j zjLC(4Q&^7eh=-7q{;5`x!+6qUM^im3yM{AcMu4X=v%E{xPNqx7_d9WqHSD-`%SvE`AMawdLIA6t@M2}=I)aBvaPOu)LPz{uu8(7h~KL4alwORl7fLj#lr-b zN4_Xpfy7>j#cV}{0`u-Pt0e8aR6gaUZ2nzi6Xk|k`xt}Er`+BWdmbZj=j(f^(D&FM zAbbC)#tK3AEJ9TGp(EF$^32~)W;k%OMa!=7TN_^X(<huPn1E*cr)*ZF&GsL)~y<*_o8(;hM*0^`@$Cn^?tPGW1I7(^N zlYezXX?1+IRZwnDU|y}-fC!zM4^x*!I3XIK-U05k#Q6Z2ga3dZs>F9)_M#k8!k%Y+ zxJoSywmqruDs4|4-YH*Q&VD=w=?|5ugr!vC>a5FDlz$I5PUJ5?4}?>%qRblL^pL)7UG2G##m9vC$Pk%kIkFZva$TNw&VMFn%D_C~W^-PU32N#zj4RC};0PIK)&o>a&Rg|(Kq92fLxJGoB{r1du z5afG^Fv!-?VaR$j<_AawC`&qqB(&0>yl$nDD}Z9iNCJ{~Z;^_MUTBo%F-w$jY z^&+a!pFV=net-b+kC{h%=oNLq?Vy19&2qpy<{@R|WI7ypC6C(U_xs-!#60E6>xw$A zD(Q${n!_htDXhE{&ST%Ao=5;wWm=oo_l_+50PWDH0pfQVkSq%k zJT^!NQ!n+yh4VX(jwA^N?^HOsAX8Q2W3q-=p}g&}-rMLzJhevneOPqySlYmx=};Ze z>TL;zHu9`Y7M?cfJmq3}Q?g)W(YycSyGZH`<^2pg3RcP*1(%r~AZ+5+LSC0=?H9f- z5@|j$G0ou~d)C}6Gf>p0F^phB;ka^Kgh)jAd81!yzi*5?*r#$lD{VtQCaRyRO02}YM4xO_h&DK^GVE0CH5TiWU{6N&bxbs36DK~-ILA4*CcD=Vil;u`M)}>q~S2WR8ffoaaj8|@o8K$3n$JQ)a z^FA}!cl2~hO4vfn)B~0K8f|;S^=v;tso#HqAX|s?I=qwxT9VG7hJ+cJ9=vv~nN+Np 
z(^3zU-e<%L$#1`FEX?)Vq@=E==I51Yqy48@x4*D&|CRqUWe~tiXx1m(7=}YXe6zV4 zL-=lU^+ZnPA^y1UY%er;1)d&*{Q(CE+xN$OZUfnx;N1v5f+z)ob%CndD*Dg{99vfF zSJVA6C{P<~>5 zWum4t_QrGvPm9Vn8VNr(HSyE2tvWj5PV$+5DEYLhsMaWc9X=tGr4o6brO+!>zJ7`_iidzyN&93Mf#1vq-$yX?;^$V$b6cF;CP8Vv@1hLtA8o{ee^8!;@(J>>2y`v*;Wn{Wk3IO>M-3!GWg@{ zW}bxFMH7?M>-6#wBGaxoeq@yak3H(YMO7sSzCBZrb2*~_`ROD-KmVq2`$#U6Q}!y) zzpj5~5L>nYyxm`XB=RZY;dx`U5gUt-UfPF*(K2`XW&^&~`0m4nhLByWB`wKYYz?@*Q-snn)<#O6E_2>Wt6EoQyjEWWwDa`F zG$N(iLwQUtR$I9D3sLY;f!RP3Q<_m=`ATiKgo$Q!ww{q$b-dfV>+5a1o`1|8bGp#g z2|J;0dxM3nEnpL7cx5T~%y_N%h(4#WT8SfZ8>S7gF~;V4G2%Q`yeHxkyX&Qw3#-Rt&$&;W6y7ZVdZM#a zNlS?#Zu}_=bhr%2Ois@40(YNg(1Z0aBka+Lv<27{XL3T9FZINs#`aUMl z*zU|Uk49H%A#FT3dY68>D;IZe;GtdCxh_qy55Q_HygtFvtjP(N!(Dw1;im zg&a7)FFi-$qVxd(Z##@SY6ZKYJ)9SX}-s4%|6TvM0_V^K!GgZV- zDG57jPQ zEn*NigY-&n5O<#VZCd|<>9mX>jd;E@6p zqLYe?39&H}1LofXQU!5|S?V6I`T>W>XsX;0E3=L6t6s*ZCB-roh=!-WbUr9QR77Fcv!)ny0N!DNe_=CMwro6feg>$*NBFL#S2 zuDbYY47EIHjHZdtqyN?=mI8Wj9@#_&+;9Ql`^{`?F&lvImu*tbnSQx;w(!xZmJ^%E z0=UNS!6J=6M0O>%oX@Z`_d!p}xJhX-%>ypYFg^hO*6upS@;hHR*NQ((^#^F8$I zn9kpUU(??DKjKDV`v78KwI!Moz*Ih1GP{eo9sPZ=#l3tz0H~9W+okUhlLi~H0Vs?~LG(nt zNpTr61o1aJ`R&8K1biWKdi^-r!krOE8_EhDBH3rm&FKic;DqPc0=zJt6G{T^ysaX~ z;ph9PLAZ71t8b{6t4lG}MNRju#&twuM1JWE!e?+!H&RrVRNtZLzbb8YnGeJcS<#I?mLma|7{`te~}{lk7=Ue-_k_y z{=OFPuO%}7^zHl@Ke~p|571{|lAgtOTp=(K>O6wXN)I{9fcf?K^TA}zS{YgepctV< z(nOJ3sdQQfq?FbpP4fV56z|#P9sfnkA>yIz=~aDATK=DikS<$XrB^+FR7+&DlIs(I z31N$F1Au!GV?p33i;*SQvQ~8d^bwu)T$7!L>)tKh-sjC6b%tUFKfEjNt64_dukML6 zUi*onb>|Tc7xPA!Fizuhw0QhQYTfR5O@~v9Z<=!tEPRc_IX^b^LsrEGQv)>;4H&jKXL^(|=m z?+4^P*}4$8s*+;EvU}t_0)%IHV+j$*t&@s%?(@j~vC_=K{LwtcKP7A0V!P_kv z=Ww@|M!I_dEzu2-X`sMs={^v6hJuL+h;3!W?3Fs~9|32jdvqoGlaGig?OpvGTgW`k zlX@@%2t;|6et?X^>9@tV#r(%>WiAgA0V!qYxWKczdh@j2abdH?|NnS6FJ@hZOO+}LH zRhsZ(A(464(@bTD*>A2;9{!vN3}duixXK}>SD>SAnwg;MLr$MKVE{`&WugvQsTR@4 z;&l!LKa9-GblgEELcNoa8!3uYo?TZ1nU!(%0>-Y#QSukFoR4?f|sbVR78DG`%F z9Aa8ih1?*2?7!i|{!^h&ZlN*P3Ui~Ujy2zkf0L29zm=g}zaTiVxpM5`g6-n<*B<_l zOp?t;E3$J<*5}=_3951-Y0IWLyru6gvPPgj?mrUX#d$&bGmt zJwNb3?KM;^-ZE5{G44XX{{aVQ(6q2vBRawY(NxD)tC{IKQm%OBs(`0|TbrhRP>mGUX&@|#C8PX!iSV2&D ze>eWFdyoJL?jvv z9K1;(`O^n>R59B%ezcxs^u%ud^Pj|QayNi!y-iDog1aiZD56c?yMdb(eo0-?tI9oA;Olz=%OXv-Z?&?t zW_hp3>RXDCVyGCS#&++!oUM#H7@jK?qU)!1_&R_gsmBtRi=56jg*P zB?GoMQA3~a7+E9s$9MvGoG=3TnOK`(kC1mA}iyrU;f1Nq@}I z@k;4LG&Pi62P(lsGAGMjJ^6R(Pg0!*&t89d=4Dzo&$#Xf5<++t7>HNzzq7wKbvpTf z8@msNd^Px|vB%5=J^PcfH)EU=0b?I+CCsZj)j3t_KYtBaw!&15t|h?(Ma#>#3{cku zESv7M8wcL;;`9!2@?hopI>onD9H3*jP??F9e;BA{LbvWd+J6q0o1>U4KIB<4-cz)j z_g8mw_E-1026&x1h%07K46~~<2qCY6fBP_adz

z44VyGvv8_LED4#zR;x3%(aV9r6tJ&IClABWe1W8OrGnad7$M|aHT|~TGu^gw-_V3* zM+SHHes;b>dFCb!13C;{FZdya95+}_OKQ@tU7 zUkezZe{R363~oGu^D$q%x(ZL9F@8FKYUhGa#!**k84HtTm`sTxhT4CS3VCPX=UL|MyM+-9htPQVtF%@4BoOPy?q+EyXMM~3;cN; zJ_Qr4_Na6#OH0&^`xbRqop0~b;ZbWSU){;E2X(lcNaiNgy)J3Sts0fP+>`CcydwML z3k2J{fgAl)!iHRP>+oA@SThW>o1LiIurl67$W>B$^#u|WIScY0hg@MX4Yik2rX;Qu zS*(3#ty+&+8y)%JwR-#Vh7GF+C6O;WEhT)HHk-Ud`-e7uP(2zLKqZa*;| zUv{O+Na}1=Q*n`gl==OavEv&V^Y`H)ZpU)ZI}$&2#;s@Hvl5ZPfhbyW%Qi96CxBE> z9-7H{_h6Olsgrr7&QHZW+2~`;_LGTodQpo`t*)M)R#!*pz=R5@Np^x?IQw?${nS#o&NdM@a|6l@%ECuLRO3S&I{y<-MS)7@{1?J&;Grsf zyd#g795*b!8X_b*=ugG#zhB?4bvdLb6}nATeCxjvh3_@0)@4E5HCMF#Y-@FvNw#C` zYe43wl4;98)b&080(s{EeAm4*f7H;BjA{FF{~f&;zEJT58|4L_so$q$w-ySL?e7ds zwyC+qWul@1j96++bm(sVTIfLXZuqlyVe{Me>V)o2U!F<~dfBzS~6G`48bLS|8B# zh|xT}BuG{Qo%SLSvT6Q3N4<@nFjL|~7kcyn0yy5l`v)N6{5T7uJg5{d9r;>zVcEsE z4M^~3x{bveAIj1~umuius|x7wHA63nTfkNuK(Ul8DVV@^cD|||CtP+HL7745Nv4H} z<^*5PKs~3@)5x&^`Kd=E_HM{?l&HsObHqOVtdy6|;m!T|C5~SpCh?s%Oqzs&h4ee% zyEMo`eEzajzuTuiq*Z>iv)_3GOZk%Q}PDf#KklkGFvaQ|KP0~?=M-;$A4lQc}A7L3adhe#pu@h zmPirlN*H<{299r;hhXboJ=GVl>Ed~Gj>ti+yQx(J-EVIoS7FH2VeijfP||h%5>K_P zpn&sn?O-+=eDkSi_#;CopUs+dEUn5S!e3zElX96^)KZbxq4!T{n!9cvXL6|H9ErvE z(MFEgPOO@Xp!0KzCl_9pA>&{+kjsl9DpRC~#{n zhKu>prD;-{bAsXh^~(^EprW?@NE>M z89BPltA?rYpAq_NaBh$G1fK2{>HC!jCwm`uT@sJ^Kyy~@lRdfkOjk+&0z2Co7LDYr zC^W<4`Fk&~NjXX|nVybQ)f$Z>P-j_0eMF3knE=ZO;3}Rn(6Eio?#tALstSXmL($Ix zI8}Z!ELy?Ap#_H`$~OV5o_u?HlAqal|37TWZ~gBMPFJr2L#>Pk7iNcrR%}Y*yjl$3 z}9&ZIrrj0bS99_uN9WEVezz>tsD=J@3g zKkv4kujlJ)gcLFEE8D8*XYTV$2Va$hyQ)!Az0Q==EX>|Tv2cVqIk)=yU8dpiylRe_ z`4vPE8Z#+9J&sJ2x4O^`FO6KDCL`v5m@YD-E}0MIIHCWsY($2Y^|P+5om(~6tc1<1 zq`O6`*8r<&mD*7aj;Q!7i%l`#yeUyj*Ck2Ti)Aa}Z2J>Dsge`QZ7oLOhun;?=xRYF z)jfEwoY;MLHjLkGaFTY)TKU~u9VJ1gmEbMB1J3m6YKgkd45@-7|@|dEF zG&fZ1!nj#gF8r9L&fkm5*eYmcYoOfD@YD;uBrXD5wn=>u zf37&GrHC7kccO1#WQ32nw+W*9f0fivfekT4|wXUbZjh+ z@DdL}r3)fPdWF{+H?-o|GH61~o!3q6y}m&DkZ`$vM^J>f#lm8! ztZ_ykJU3=xw^=~4vu%?gr@j?;%ws>SB!~7EvZ7e&DC$*;(uu%_^x-gRl%d0&>FW`; z_X{&_%F}5){e}VL1)K&A!7bdEFm)ei6G!o=>kgi$HZI3 z5))Ez;FA9Xe2tC_MH;HS)`1>{T7r0d1TZ1!ZGYNDqk@Q0M*uB{x`Z*+8$3RmBGthY{`}#3MuaQ1AYbNPWPWaJzM0IrmYr76ICIdHd24EV#N@J{qhFJO&t-5I-`WJTv!(v=_g z`Cs6akdBp3U0q+5{`DdK=02_fT_8p&>%kY}a9TvK zV}KDtYu{o^wIMWj(IG#xXuvy4dpK=Q%o zd(kwoQ6w&o_}&SS zNHNv3Ob^c>=LZ;cv2+@F>(RfSFjc~DRtovgBO3HoxTG2x76i5V-onXCKG~HBA5$|A zk;S?WK{E7Zf1}yK&m5g(y>0rnBn$%}oAm=a{~WtC{RfiYLAdn?aR(iwg%#M-nF!)E zV$#CkG;BJ8(1y=w`q#7%&#AJ0Gnc8EXaKOfqRKK^C?S-)qgem2Ebm7V1y^4M&LQt6 z$-fYY*VFAG7vHKFoVqK+mBRcgFFrY61q#JQtljOx3!86-`AMl1i9P=`xp?A;C$ni! zu6Of2nXl>>1}&pjF}gp@uCuT*@#L5w*Jqv=_jH((uW`+@Q&Y5#mnCCS!BV2UUehs? 
zs&R~Wh`6<^1Oz=!VVc)#eEwLH$KGOwYpXkNl^GW9X$!QbJbNuhS`B$p_BMPNBM4!bTl>YtX{i;Ib@Y$em_}cq_>{x|TdzZ@FLqal`3(q{ z2B|C(-^7U7S@lp};P4ZaaU|Xw}p!0vEf8WOAcj zj4T|{z&Q->$#YUk;f!q7&~`9>V(owr>x#NGL3b!ZJ*~KR8a^s<`mF^+KW+MX71Pft z`&Q7s_)^uvG?Zsl8O-zm2-_^sg~(Z-w~D?2JU_US#1>IB6QkOzSM9-G6r+j@yS3>_ zWahr$Wz3Cg*pcKteG>Qa3nbHisJ;4XyyU8dYAG{}YUX2bif7feOsmH+rR%U|nYW~1 zZ`}>9sJX+7%LZQUbjp2l^qp6;6X?S5b9*OQ zv~v=mMQ~ZUS~%$l$YBq3PyoyT63T|7=F&jQ4LMe8KUg#0rC9<^QWdPIfK6oft$0ip#r)cO}R6#v$o zN=e%AQT*?)e!gMBuXz}!^7NEW`L1q$Ew1@F$ls3ox6?J1+m5X@X76sK@^SUEY~EDW zFu#S`rxs=xY^RpSPZN)%8}6IYVA3N60z0rZ#QIc*-=aJ^NOz_hYFtvFCK@0&9}=oB zOEoMVFlE@iojBq&S;iWi716SqO4ziB@@Sw6C%I}_Vj3rwii`t{HQ9Hms zGvm5Fj&hEp`#<&2rY=i23&Laf;>`|@3q^zukap{S{-^rRDx7V14e}F9(>#AI~Tk0LA~IG2BPTC1N%T*4Cz`Z_P7BH3f6^OpF-Ds>xCm(i#C$Ud`oAj zNOadA7Ne(!414F7GMC&6YC#9M3F$3DwzrH!Mr9gDtT%{xpxi=*c$#$4|45+Z0Nhwa zCTyVt&~0uTHVD&!(5g>iaJ5hB)BC21y+c#14xd-Ea<`2>7}yR_;0wbaAtp7x#^Fdt z{HHyyjZx-iDKzg#NH3mCx%qX zx(#}b1>jG@nL{vcr{7q$(^SXS#S%8LdzV6t5Cu3%pZ<97Dy_#~S(1ds_|o zTuW!FxhHX3tD*xF`|(YlMMs2{WrpJ*T$y_{w%^64@($B$c60ONu_`h-u9CWa1G*NU z6Amv@SGPNx)WC1*`OhijWtitE`8>L)O-I4zsmeTgs)a4=IZ<-PP^7IyjGs4?O+*2^ z!_rCdD3$6xTe51qGJXOQh7tIfStWyry%oD`N0dJ9BXo( z5otg3brh4E<>DD6YO**pbFs|ZB0#43La;X{zML)W#rw4x__fM@SwC&2`{n^k$D6kn zZhN@bJYx=a$yd!#x1x0lQkwc~FgIb>M|yljml)@SIr16aVZ!EIFi{0fZ)#G`ZFHP% za;HdgvsmnU0ouTD%EGsm9SZjfB^^KtewN#5IQxmyi0xS*E_HaK`Y={zzIkER_k(s` zBa4aCe7sSODzoIrmMagngc&Eq&4%<3OBU&*jM=+#u-q$?NVuDE{blFP56Sp`-3^c+ zO5Y?sUZ?~aZs+w{4rUw@WSsZqS0^c^3BE8{3Trs>*15&wy&YUns}hm7zmcM5VQ+K4 zf99wj8%A$1ib6MsKdz)!FzFhgLJ+x@C*94$cwrm2fmOVEpHx8@?2aPE zp9CFe+my_<)&16yNUF=iNXGItg&1zCS$2x}xw~MX*L%Dg`MGdiCcleRrJ1(EMq$I$ zn={nJ2j7r%{+)+8Cv?}`a%Xg(1Fp%HAf}Mwp{;Hm67`m{&HhpQz3QvmHcS!~|Mhie z(*hvmLx}NaAl=%?3to6*cv_2*Q%D6E?3zap#5WRD7rkzifVqFP#sh*>xHH^LTI%hMx{>(uZ2O zpH5biX-wc9oDQMr*juzsV z&Nz-nu#6%F-7^sX(V2jSpj7ljOmb5Y)&S*?0?<-2YKX!7uQ-4?r%@YA-b4Zlr?v}< zQ<9Pd(Jdy}LbiVN(_w4sFov1S4twqy6}%sS6)R3k*bu`zAt|Xr4o~jheG%0o zST3Xch^ph+emixEz~i%BASSKQC%_TgKn>Ho=u@ zRO^UN3ggw>u|Q(Jw16n@cp4^gLaYZquTY&;gwYenzZD+hNOi8?j9sGkN9QT}xbcMD z#B>>0Ye?-1CJu`r%0@u>jIqlqSBc_PtaTVs+#H}xQs+Hvw|9`d9eMmk?dy&QHHwRT zumvr&n;4D)10%1A=?xdk&A4~DgY#sKOsu0g0we*xbj? zEz!zP0@?TPPV29VT~AyEG*M%1E`sikQWO^!o&1SH=8ti4!v!S;XCU6}Y4>vIq~z%!xp2Tay-UNS>ByXVyE#PhYdJN<=zihaX|L4XyG?xf%DJ}kQqs2 zuRteUy<$LlAZ|cavn96WZ3b;v{#jkRq{RvedADNp85<7~teD?mrPy@wL!L-BYv=$y z+HyFIiV!P|d-5LeiPWatac;TP-d)K!s#8G)j-x>LR%LH7s{yyt)@k^pnmY(fMO#d* zpl)m=Qj-XWBK`;mJ23Erq3fSA2LEDs|Cf$zg*Gm(a->EgF8)0G%RMKY!dz!aRfXLe zizXb-skq@5>2L+==BzOGf`7jbyYDgkb zirZ8JB#A_Ae2sB_*QNCLw`nIR>Hcq9%q*qussSj+=~Bep9>=(gHBMm*T23UxDq=O@ zR|YrkTJ@-X)DY?qx00uR$!-$PLWk{uW$x{;Q?XLrJ{0M+`tqB1lxMZMe z@5dYLe<47BbG-kX4J3aU1ojMU%77EG)zA9Vbra_qKyE@J3jD|*B@mv(h{u23wTcoK zP`fHE6V?46JRtvz`GLw|L%zjjQ;EQ=Br@Is_z3i~Xa~uVUR%?EL)9wECGcXW9;mPb znl821li$<>>|dMxq5~5d0~WzV9LiKBInQdQx>q9YRixsMB1Q&;H}0VOIX-5P2=M=- zbMl?@V_^7u`cqfB|(>A2Tr%2pJ?v zMt!sBe*Hb!c>mrvby5e-OrP%sHKrhtk^st2-NF`8QaBw;E8?EKI{{2j7i^V4D5SPC zf6C!k@)jrLPcGjd?8fxB0wJjH{KG%!yETtKr-BZ%XXfC@t3G6Na$hdx^nor{&jb*V zmUE-HOE2D4y{`IfdmUtFp(5S)|3!7`ryo!I!7Y8fnvM+qor_lSJK6Ya^>a+Re}fSG z`20KI%&ExzgY`ll13bZ%!OeA$FqEYuOZjKh0ZaB5vme8200!%`X^;0J(ycM{Nh!KO z)!oq{j=?%>@!VU98P4C!Crqc{OZ##`T@(o00$>%s?~KwvDBCmW)OI2MDf4nRaIVVE z-B;a0$+4c--7NgHZA5q~l2 zf4*mYwG0JG=iJUCXKn%#Cai(e_mwRQV?)#d)%KNDfe8TR^8X+y{yqtfz@t0R-#?xP z$01}taxXbZpXr^4%{T462;<$IKbanRJt|g+{s?0e;~i(~P$He1C@&=cL!%75Igja0YjWtyja(U$kc$@{jkhLiBQ! 
zzHwLTSo6W1wL?l^Gi`t(wd6Sqd=b;t8Hok&68>Q~M1@>(pZWeI^Sm~GiLi53R?r7o zhZcr}EA(1c6QCi~ukBPgV7u-e4W$)JaoGDNmbe6>%%d-)YJIT`QTn1vswkca_{Pco zWv3URK`Cv0m9ajVm@qYe5wh$W+38Snqd_K=!T0^qs*@2RJarR2Fqbqb5xGV)*w3>= zOkJY(Er+bYYZvDCun?0=TSBxTR^>>elDh;qs*mu}g!Xkq3anhAl4AJO5SBm*-1UF;_ ztjGZIU7E{0I?K-D1v?#7IK~I{$OKcIV%%u6${Qba$+)aExfX~H#DV{msI3dA-n?imLK(FWlsK5%=+_m zhW%7A+KK=CD3b@2*t+aKsE6)Ix9%UcWQu=)~n0Xc(1*JIjZg7sNMAFTE=4snS zp9$@Dg(Y~hKly`FkE^op#hVm|VF3)qzz@aH=d_6mI;ksPT)3kmbM%?q&h|@eD%{Ef z=d4YW_wH70FePtG%Tlc%9xcZ5^?GzXH`j&stQl>|9#$3eu*+?T)}O0ATdN+t9Gc_# z-&l1E_`q6-vTKOd4GMtbNquS6K4y9Cxq^IA&MW3>WR0? zQafeO&^{p45H*%#X3yG3nNcq@YB8>s#+`{P zZN$5qRL7qEI0yw;9rk#nYk*QnMo>_fQ$gU@&36kNqs+l^8 zrJt@T9p|gl$Ql!q&T>18H5!VlURkynp!2=7FmI*v(i`>U_Kl(|o|b0?6RMtJwH4Pa zOF}a5W{Dca;=WcA;8JCg5?r%^mdacyHYOKF7~r*e95b1iNoaKmNs z&R&Njnl^(k8#@U*Vgi{{Bo4l~;;24zyU!qhGOU3$B26bCk8(*W&bwH~IZDJYtUWhS zTy5;))r%Btw~e?L>ric!WqUrss>iTiy2qzZJ}g>VMvAN^n>gOMB+e%G@*??z!2IC_ zZ;k^c`3z1c>&NOH?mI!F$E4DQ@QJuuie0-~6(A$*GecwLxK!qb8AW4{jH1m|7=y^F z#Jw?(C6&G;v(K)X!l+RfYfyz`em&x92jAjYa-#)h%$!XFolD^Kj=P5g8lF9SB;iU6 z7n&i8V4(r4dhXWxB^Ggw^4A09JT=sNI&a(F``{wA)!prg%oXG=_99|{e!>Q%(tjV1y{A`V_J$S z-p%yqBO;Jb*7T8SIIST3^q}%> z*SS8Po|?}RNGuA;5@a`uK2d3VD$QZ1!iK4SjhkpViK^0 zF(#=VD0_#hm-O#!vGU9=@1Leu?97JTkDZN*su?>A#mH~&ZSH25XP%Qk`7O;KwvA<> z2)d2v3vzp4TE)!iz?-mpF>Dh|_0<+yg0Ddh)71_l`BxyAYuX(Vw&5lku4d;CIDmeL zg0l+fUkSequKg*4LDu!1d1gDe=61rs6*AO!sxy-RDd-TR(ytJG)Px;bAGtx1!`wCE zC&VPZqbKZ3`+s%jfm~RpYXwdCn9duG%`Wd$Wg0=(e$Q}LQl_ne47*81HFfQH#^r;_Gw02K1b#+xt?yFyZ z;y^hTG3Ug4#^6Co)B~ zdG4~2k-=>Yp>|YlS_Nxqbw+T?ky9TBwJ|Tl_lKqFBlR6n$kS*_yaH|k)o*Ud_l`y8 znOhO_-bjP^gYuMTjXdu)yDzzFH4o?3olniFL|JI%^thJBtY)u$qN2^X8dzm?`wh-x z370}-njc(EpYk|MNUrLJEr#L+b1knBjm9!N6m*;r9ZcCpEidj;-=U8iTIg9e#9ny; z3}Tb&ttpd)TdcEXcc!!*b|_Mel5Yicyi`=$q%9vjfM;IZ?38=nC)}%JbEGp_YRO&i z&6+;7`3m%m@IqK~6r1l$E~Nq^A(M7~418i4uNW}0mr26a`MkEC38kLlkZpH#cL=>e<2`T9sWMC8tX#oKNDJjX3ZctjfyStg89AFrKTkkmH zdCq-5=Y7ue`{RB8V8;Cs_FntCuC>;+uK0e9AcZv$LiN7AJc*kHL8^WHCPgJakJ$pQ ze*j;b9@rO4xwC7GMPpWee6FvVP)9|#Ecd@ln@HOlIc*%mcdeU~c6foTSvh3yq5&BK zjihW`4pZm9Ab=*K%h2mkdw z0o$v?D&*b3R)CjylYe`$c$9Mi5c?MU0GmI#zrDqcXLlDUfb}dpaa03#v9^><74U&) zSr=rG3x1%W|C>mD9h@h1_mAg?{XPnPUj!y}czS=84saKLn>YHtx#7TLY&?M>?*-Kz zKg1cm|CcZP%RPl;9{`2lTd6>KH~3IvlMQ=K+-ScZ*l0EF0?op(l#@!!;5BeviD4f{ z*fCE#aKW~XZA~V-{nyj@ee92U+xz3Qhlq;~69znI7^r$*XogVVRcTR+7Vt@lI0xki-jU<}()Oc!K+%n;}pX9N|TMokDy0YuH2^j)Wb zfN)D-KlBg!MG(KL_%yf8z;8blXscFzfi*gvHAylRG`LQS-a~4TYY`2Agd5 zzWahAUvyBP!Dbc>n5s6+{w8B+eC9eW9iwKRHO*$!nCOFR(tN;CEm*kn%EdRK#B7As z=6a~5lKHBtX#&`#cKU=(7ck|O&gge-uD7Eby#)~GUsF0!--W z6Cx3uJ9vDiPxVIETjpYXF{%KJjc%Av^qkJdetW=DA6r>P|JaauP)BIrk$JR5e)ReC z?q=bP_lUKuKu6Ci)EJ=I6vP&JS)ybf?y> z382BbaY~|>??S(>H9uAJx_W7#wh4aP*TTjfLgjqoar(!VlAz-}zMQM9%Sc313|0nh zH;Z6cD@ExVT8p1EP$lK-q;|ahBKBS>{Rg=-%QVwk-QdFQAPm=>kvEq5BXX<6EZe|R zS0lnu!PbY7U-1#Fm5x=Kr}qe7{zi00i5xm^ zqC{VxV3t(WODiqpX6UMhCOHc%Q&;u*DyhhCIGQO zrEh89a^79k1pjRCI~}1e2UAeNr zyb`_u)PeH1d{rs|xfN?&{EIjI4)pUAmD(v*yLO>|G2wqAeb8yT^8JNGu>o0qXD%wsdmna&%AffrH^?3ig)e@T+hFxBI5>~s)dh$)&ZP&#RbywM(PfAevAK}@_w@j1{X$(F7c<$FsNswT7D z&-Yh+r1M`|@*u8X(jM_-5HDMQaGH$o6bqh-7R6nw{y?t}YYp?(>Q)ak=7ygPFF?P( zlE;GyglY8gZ#w&xsZ^NiT-R6K1yz9E$LZzY=Gtf-+(U*8YzjnXPmZLLQ`X7omFQ~NId_L$0Kye>!m6G@@xx34IR`ta~C#! 
zq|yC0<0Skd1hqXEZS*zw#TN`GefAX?-AS^jq(Vl>ToY~k>nxW2`s~nY-V1p0gZoB5 zY3GzBYn`9LkhrD2GkTMBSTqzCP*J2`vd9=zoI-cV4$Ff0)>+x`aK@2os?_m5&vIf8 z{^W*xRRq0z*ME*^Q!2ADdKc}o^%iHhBvOW%)Y0UT@U!7T$j^p%ia_rHHdnH4(%4Gy z_}L;H(fc9q)#v*$CslS|k}ev-o}z2gt&VK7=IYM8qV4b0&@8Da72Yz)^LKq5{&~i@ zlDN(~t!q4qBg@PXIpoXG5>W&=1%IhLdM%y6XvTVvHxiAiUeQRmC#c27}fCzi9 z3Zd{r#+COcALQv<&%nzRLOVgr*h41++ss*dmsrxVOz+~H*t^18mjvvu2#-4*4N*^@ z6+P1`IHM|Ra$L9pWU?HT2A!ZQIoAyftapgs+7IkUGp51bL>h%=G=;TH`+V))HgI_) zyJ!~pkPZCsGlmSNg(8{{3z0R-RpX5RFkz{yTX0TqytS9*InvbhLJ?~w7v4uiDguaO zY?B+(S-`{&6up$9ywHX+M0|1p-wG{M$pEQYBp!em7pv}#EO~Xvezz;6_sy~MihCz4 zjCQW*mU{13T!KGJZb`ztH)P7!Ad8cQj(SdJuJf2KE=&fs6fWd=CWNDaTvDY2elOmGMwb?q#2X=yT=IqAy8Kce?L+@W!TZ zi!}|X$nHwysPGmQY-z{sJL2Ky3RB=S9v;_Ip~q+CeTf=|kfAAaHX;h(Cw+p>U!Feq zz@Hmq=-XN&3e`3_>6hqa@}a2E!IXf;k(bD1=IR;~E)56i1&HR|`*57&inw{>tb+;sVkye9qIU zMWT+XS|@|$jAoN_jmX_-?Tm_0Ji!D9vG%KdHxVMak4-|ANW#6fi=xy3#9k5X`E62w zOwk>(-I0V*wZ{^p{QRw5=3a>A%T5~W&{0Yq8WP`Hk(1FX8BRdsVdY^u-$0M6!cNYU zfvkWtpwsO1;U*UVqlo1UL*In~^^9^rD5VYnENJ|#+Xnd#?S}PG{tmkcK-vJUq)EW( zx!dvwwkGO3I1>P$!X*!sw5kL`R$_Qcv~hp-N4KX!-j;az1`^Ov=oUFZ>AurzJ; z7QzDMj+KznJj5@O7|J}2n*wm7{JT>5uSOhUIVFxnqihiuqwk1nqEGq{h)&|xA+u7* zZOlO=fYnvhY~bFVV|b9f%`*6k(Rn)BxV$2hQfh zdPz%G=0&3EC^o?ctXR*b($}(urgfMd_%&>&9sFkt&Ij=9<;(y!&?AzS99hVx5Y%*2 zIDqK%*CTrB_)jJrS&b_5?+$ie#V>rM*FmwBG<-4!P;?2x7n}iPwK2OTZqP3l=wx^a zmGMGw^g!ieGl1`Dik%u%=1BmM7Fa*qp zk!Sm!SN_y;6ax?yT198h*%7VNs;`a(xaTbq7%(=-u`HU)6A{8Xj3oPrD<$+}_ zcB@xmVu_AZ2lpsy_8J#n*n=+XRppt$B~HJ`QC1BoX&+?g5yk8a{{`MJsS%H71EM*U zrmVMCQ}&>Vyv-V)8gMSx-YPwDsYT>xE-y~k2YF1>Z}}~>%pO0tWsjD81-$|ECh#df zbg_%Ng^%-PH4qh$C$J%=vzZ$^WqnznC|LO8PQc+O=6s4tSnO4FgkZd3YBolI9OZ$E!Uk9ZJ_GnKVo znHjL{Bdzs{E6#2siaS7%lfx#JVg04=nkVnR`88x|K=PBNd5zOnQ5uTb*u`^y zsOdh+jWE0PV-`2UWz{Rl!{Uv5)v76M6qz}n7zxBNG%KG}v57S&IUSfWsemtCXJGP{ z)c@F`FX*V?BCWMaGh^)@do6+0TqyDq=tHv5u5fyxFbFTJbv_T8k$Ef*ZEm?2+I%+0 za=1;=UE&Jvbtw&f(-&(FNqq$zCyBM39M!@t=G6VSMHcvyZ`C$_e294)xR4oZPHjSB z75SB>w6-N(@6zl$)!R{vuT;ilFv@laBbg$TrX*$o*&Eh#P`>com7)g23g_w+!-8FTkNP(0afmGRZHR@ke5Zp8G4e=zdqmt$Mx?}m?T^9Xb`PNR-mW*)d z=vropmzJSrTxsA?(PRThR-)wABQM2;*S1RPavoaQP=Gf#Ks4$66gqn^zN)H5WCOJI ziTu@jDRQTz1UJf-Daj_+MP?fw$2Sft?nb?OM*S$+T6Sc10`$7O2l~;_aq-^HXw?zh z_AokR7MgjbT{P%9IV^L{!*i^-Rt7oF?GksLxK;ka9?+Iyy(>RIXA~q*k|3fNx0x1m zJ9J92H@3R;97x_i&2~yruVQ~jYbp_pM$7^QnD{vf*T+0s-f;vJStVhuQ_UulB`#JuG(BgMZo`shdNvK%F{!7evY**1n1` z_9CPxaREwdPiC;X$vUIOI9h0Yvc*-0U7*8wQR}s2H`fa9=JZ+EL*&<> zYAw&%noIZQx(3b3y)*?1s~z(x%<3tz`z~hh?)x=8be=iZ>#V|RN{Vb#pwJ|gd72WF zRo12z>Qx5YUoY_+HCN)TY36gl1SYVn-#|5c`SHKF6sY0Ni}|CUQEE3UOJ6H_KNi;f@19;ZYsblkq{j^SdD8KH-wAjl91H)AOyZMGiJZqrhU_pIKl*xlyMf4|taQMNNQj(p{0zA9cW5E*tE6}F1JkK^;NF@&-%Rj|L@MWE>Q#}F%**tHX|_{F)g}U&-F)I zK3M4{oGg~qwCU+9v)^lY6vIcE`MPm&!OYATHzBz1qG8>vc`)4k*d^>Um;_^xViglv z9792tkwnZyG2!h(c#$s*XB9v+?fOvO20AS&1AT4(=qKi`Q}cW+ z?8vE-f!?dEf+^@54H4+$#)Cecmn=24?+QY?cl^!^VFe!~S_6p^A8>u$?hIrb$av1> zF#S`~6|@MMB|;)FXo}-az-`VxeYLidO#T3S4i41nc0UIet?6%|z|>!^^Ya6Ly$K3D zM$y1X8v{@SY}F4a2zP@>8Nvt-*KQ#)tUtpoA%`3>N&6@}pgN$3$ZR%@0R6;{%u+7)zJnYm-(24 zWJ|hc$YhfVyNw=nhR?G0(_Qa*J5#%??TSk16W4vPk;A+Xwc0!tu`(l6Hr1~fzNnU?`@*5e#U?15SDx;8RyF&i9C7pe>&b>BAbdx}qm3#jU40^PlBQ{mV@>y3j?R<>|s2Vz#0$P%^Z1VN*}C_r`uf@rubrv(-1( zkHk|(Qo(@(x-+~)>-3s1{R_00l-j8`2|IOFS;>v}=JX=;;q+xr4nQb&`W6+ub_CEz z=&l11WAUI{T@rN$)7#$-5Gc(3lb_%o{@)qKe?W8b=y*V4Y7}ear57mAhVRy89uK-H z1AIX7Dn%y7!jFSbzG7DZWtPDn{MQNbdxU>K4}R6t|Dm!MP{1AmP*JFXM-+RFGnu9l zv4B%ncgnn#PNmdW{(a5-t#We!N`f=UN(<>A{#W>VZAx9=K$F1gk@5?!8j;9*C|iB! 
z`_+x13z2q|5Mb8Y!%K8etJK=UXa~O^=#T)Q%t4xYf4>y`+ZyK#dcPVcUiCpB7W7=j z#!8PM)P5J>Fkd$?J6%7WO!n*#aNeK3CBK1Q$+Hdt2qYnP^;p&+3jci;%|1N~7lKzE zL(LjvrVn*48Gkxo96?1T!{gKf|q$-46KEG;858i|)1Ro0COOHr*NX`}u7^Gy6Z~ zmwSKZg9ZHwU@W=)DrrXzqxuc>sA~G;+K3t$scFaY&u{wHv~Zy=GMO1d58 zfmmr+$a!qt99FS%Z2_k7Givys$%59?z}^)U-_s3FSAlxZ9iZm%ANY+#7f<&utiK+t zTd?VYPXZ-NJp?X$KW?Z&cH?%xn1v5x-aCo#m#>Qy3%MbRKn3mP+m~~MQzX9k8QFgW z85i|q3^K+>uCWwfwic)2AWjYM5g>4*X*Gnmt^?KQ!?|@`l}A*{Iwa~glfKmV?_2|_ zGYM}(cEA855W?E-g#6j*I?jE>pThZGaeC^fTUMP z@+W|`qqMLpY7KQqE7M2xKvnvZQDEJUt``e1CPm;in?Rmoe-w)9#0_I!Ail3J0Mz49 z!s*a34Mq%D(huT+mJHPQ-m=p5^9@Ejf|hmr+hD9WukD|RmQrM=VU-p{rw^IMl8zPu zhbB30edUkNLA=+fB+N@71Gm0IaY9mgv>`$oh*=ix+(SzG3 zE*01(RdKTzJ_)EJ%G0a@oRyW>Q3D&Y@%@eZ;X(j4*dKCs!UBHC?_9_R1V3M0W&D{MIhRQuc1_&Vu&nrJ zCguA#(sL!YJ(sAAAOdpMj2~^9C%baRG-yPRQq=>w^Fq}#N9H7wY-G#trH&)SYy&T| zM>pM8-Ez<+tLC4BoNH@6D;j=TzKORfY-2v9JKs=6zV|WTaA7m(@QuNH;sKBA8VAR( zt;hFHF$n6@PvrZt#f&Q+$iZ@62|)$ya#UMlfz*bPIJB+_;ui-GfSsj%kx9I1R&Y+q zhF8E#^9x=)^T4RM8wPs-c^ z@5w{WELoz4SJc?bXZ8^91e;HWB_nux!~`iSNoE^wZxfQ&C52D6LS1EORnF_!iz^@+T|qya zeS?9*l8du%wQcOH*V5OG@WaV8DBB!G6TjB6&3g&3ZMk+vRe0*vw8zSqqqrf$~t^?Ku8 zausv*r+M}rH2|P@PH^Cx%RhjF{wfB?(YvbhI!Lgay6W2$jZV2_yHVyuD#O2xyF(MrF@|Qr8Efj!BC3Ja ziuj|-Z7nB{vm|W0*3}}0TTk%#=4BDgv&tNnU((|Aa^yVq`a5>Q$YFO-(>`;4!`Fw* z%Sy`F(<^L{A&o^8{DtN05zsZKp}MPB0(3J4J(^_7{=G~N1H1Pu(bdB#l4F`jUKK1k zyCiiAkSBI$jVxI4w_!+XC1OfQgs#lZQ%dvFw^f_2tW|Ybr2t53n_Ks zt{!?a62o-gjd@ma3betE$M#Qw9h9{luBC2S!2Sopd|NY*Yftez*80kw`TdPEFT(x% z${@23fy}O#K;1Ubw?4;>*KdIk25^8L0106<;1JXTj`yonER7nFQk`4J9;;z5J(m*? z+!aJ?e9!hK1tE?%A?Uz=x{LK^$j|YeVZc!%$plihC?MqQ33vvS`rX0u3{czlhBCL1 zB?EUo;!rsMJbO8Fe%I*S8-xvsZ+=CoWX*#EKV*z=wM541GriJGI9VF;I>m&d_NNOK zGmkXXxqZjJ^)QF$%B}SR_qA(y@m1jYE^nsj18rNqOE+k1lh>%-TIT%$B#S zHo9l!KtSjtFbz9Z!@TjhbX4s)mIDa-C=nf(>;<2@W=*0uuK2NuzWzNZ zJJ}UvPvBmM;rWpzZyG*1NWv(9k=fr9>k#14V{jDB=8SL~z4p#MZ=?1_@DnqOX@Ll z{d!X4vy69rce1Qr)3|W6lx1-_jCRR5TxFeAaF6+`F;e6FH%4)lHU4jJ(}(WIDgbWB zS0bzrZ%PT^&05q=7%dt~&Dr@glEeeMZa-y-VW*7{v^U1vQFq#&e>LS~3e~^6%hahD%>L{_5_$1um?}Jj zf%6Q$kfi+scm9!tcQ3{Rfg}IvhyBMU&pQ+x6Qt{Z;s}sM15Vj#wQJY1p>2{a$2b=E z)VtZP$9uQl%UUiU;l>-<+dd>3hYzy@4z7Ye;&KsMZ+10zCN{G|9hGyA0mQ&CzmNkt zWrhDS4X>880wKrNO;5nB3v9b|qT!vy8MC zHO?0G-`e_r$IAXY#>xT@gfAEjz}JBR^3#PgTY>(#hR;SNmAP$-1`l7=;W&!lRrXbpT_VRb%;H-=O|2A%>-k(71tK}j7 zy62(!2D}~d?RpCFUJp9Hc$RaSt3A~MUgfmh|5Jivjn|%H1I`CQS?3E#uhH#)s^ z0Ouys;RS8oPp-a;CVhwmthGV?&s41K?&f&f%ZKnG3_fz@!hC4>|PaZ4Ya zlpQASzrhrxc{RE4<#Tc8y*!Hll{vk48mCZS?ey5K&_)fVtc-*EOe%zKI*&>m`)K$R z`bkZ3KDF#7lh6IZvh**%d6?dM_1p2+y+yfax7Ui44BIz9JYXwGK9kdhrGdAG0Y^N8 zZTMccZZ8n@l)$FOW&WT86#I-;o1OB_S)m+9XCJxTRILBWA=iwgFhyTbx2)~|+Y#D;*5dONfQE6fe z>Z-hF*bEA(%?{s^zaW}50RQ_hnbG_Ava*A=V`VoW>qpavaqH_17!~C9^?sW?_CV!y z;mI46RMd~~5^o(6u@&Tpo$?j_TVCtm`@RyGpM!6MbBsU(YaU6x4L@6iCBfi?ed&(B z#Upp4*N|JG!l_FBt+i;h8 zT{||hh|BeSmws{HxAN&C$-IIs8k!IP>-VtLwQzUeKeKDEMaundj`(N9Qn1+*f zq*s9e>$>1>er>hZ6lany)gG!lk#k|A_qk)Aujbe|xug@`xI;vL@DJ zR6ZE$oX?u@)3bz?7>i>%t~ic->Q2Th0X@s3X@=JW>*W1TAR_&pLl!bi2GFQ30f*r4 zG^!w1XDtRGdun&0B&PO*1lQu!jw6aQI{Yh*-~k+c1HLAe4gbNAtN)~GCPb?GezIrA z*4mnz<$;=L;O>pZ504k6^z+s9?-3Q<*Eh#ptlMWEg!~}U1?m8%k3y}Ieh~2X5Sc6d zLg*Xz^WA}_qEqozeIS|p)4j485kA!+0z2@{yxVF53{2J&W_l!diN{;zR`4=UgXAN` zdtkf=HMGqO>x{vMPr1k0%EyAjdle<|N~}*y2(+=1FY@ysflc^l0pS(BFYo?&kV-!f zDf!IlK<@gNTv<^wlzf)SNuF2YO?$AAd1F{PD?>pmOAT+vufp#ApU8^p6vB(UL?BG` z>C?P)#uCi&*5f-W@-q_PYC@t88OGJRf7b~9W?cH_FZ(-tw%EEPy6r{w0oLT7bd+yi zpU5b2^fZ!17dw5vt?F@-aYA13JW=FCPjfUD?jgck?W@8O7E>_3%$V zkLYRTylvtzeYmu6m|ijZS7RDH=jzf5Q~|fH0wNwRp6+-EKgAkd5no3}kI^uB3@Byl zDOK59^NyUd!cs}1@?;8~zf>=ISK$2VIRrj+=jSrg%NQ*n5WfVhr~0}@r+?QN9aKMU 
z(z9*rSMZ^8@n&bCN=WO-MgwyoX!}%jTxtBlP$e0q$bxW~BbzvKQHzm^0ZWFQX?;Mq^ z{%yONn=Wy~MApb)>wz=!O2GOM_I&-&HxQYua4?9oo@l+DCm%JJ?9HXVK&uVJR((dfb&>k!C^`HD|IVG*r9>Exzgz8nQK_E z>;HEPP`JipKL%!2Po+z-i;6dJll>Q8Fmx$GqJUC3zn^%>Pc2&t{;u*65R`FfV=a0L z{4sFD%qPd@+@t9!d_kfKCv_+jRzZ&0di^cb)EiHV%QnM8&>G-5?vS5$*QVE#9uC-y7Yfh( z82@d^`q(&-vP=Z@bZTu-OIfWO)$Kef8`~0;&>NROj+Jd;aQDOcG#{Fz3tImklKu5( zE)3&eqIUckI3#O~rTqg@rHAO-;QtHpR&G4TmRrZH1zA?Lwg6%Tm}_A?B=L6CdZN>C za?{h^e-O>lgdtWT<8cbjQA>xcUeigo8#a$7h2`X|0VGcariT=_b-9yq|2RVy!C8OQ^D2lo`-LMUl-FcxN(s*GleZarW(VMDRWi%6%d2{#4^+4>Duc?dtzRYoBiNqXX?_wH+09Wq!N1AYkml;Ls?>f~^TwB~+ z^_#46h2PQaqy}Um;kY##Ud7M?d)EKqhyRr=oe*;nn&6VTy@TyP-%`aeUTz{eFL^Y$ z#*LeiWPQMN)|zD;+fAiZMc>tgC%Mefr1$>yYeF^9J>1?~#E%&6UR$74`7JN49@_mR zRSt_e>{XzhaNdbi?DG7IjdK0x3H0lP`jPiwY8X>GAP=)kcdvzq*Da5i$-FJQ17N%S z&_7UV0cZo4=~Zp79UsH^60g9sc{Umqg?@OUXqAmb$F5gbmD{{6ob;WSOXfce@ZkWz z*W_a-lPdGEOVF<-l2d8s%;pRkmeEF|V=nngbsfi|%B&kEuD!wY@66Y=KKyBoq?<-+ z*uj1Q=k4z*r=AQM>4uq_@8*cV_KY54vKg=bLkj?^A<5{crR?TQz}<>n#(D5YrJaf z%2N}zd^zhU-`_bLFx?c=Tw&q`08ID02=JKn`XHr`%ltccU;*<!0&zqd-0I6?uJ}dIz*(;5$??*cJNxwO!*QeESF?;Fo_Pm6EvrysM&?peZ19OX6v@4Zz4s1^4l1GlRX~e?yY0Pf=sI^+wPBN{5mV9m-447=D|09 zziITn>k&A%FNYF-*aCgOC97+dO`d!Cv8)OOmBB5 zP~W#{QM;V>ba@)Ss$~HP^ru>)Q~z%6Ee_iVEA#r=d2BOK263~U76A)BcxQqD-yZ$x z&zkG{T`XBoiBC*(rY}JSk zqcjIU2VDYh+MK^VkkKcaC+w*EMZJ;FO{AUBRYf%Bx(3;WN8!!!xDkuBz6eV60~cCf z1iE{6@R-2zgEu7LYCj@2a8kid65$MV9bC5{f zYxi{9O#hAEiLnQO^d+Gt%x!bFwW2MzH-d0B$@Vt`tClzN6UWM+%3is22kKf`jcL5= z8MAoiInH-*Y0zda!aAdLqPKWQzdEZ@VXB}p{<k~!66yktwm@NXf! zgNYvzi{k7~`^etNMZatTQ_wX=o=q@OV&&_z<08pA2ykAVDLucb&vUpEbom~L2?=WE zoy?Q#tiPPrub!g(DquI2=u5=E^sZEKlL`gdRwsbqG~I6p$5{x%7zlJ zn^E2@;Y~VQ7L<=M$hm`RG^sTZ-)(bmy-zMTN4zVcY9UK0qh8Ds-F zSD&G)+BPO}+YK>31k#~NbQU8LtJhN3Zxg&G1_YikEfxo*|$D<^lG_X>QyU_CQm~%_Tn0t z{u0Wr{#6i+ z>-;c#6h>%PhDpk-P?=G3yX5Sy(529ARgs2tm~TWGI>pl z|G=8p+H%e}jm=kh;tOLlz4xpsZxa7UoB(7@QzeCWoq0FOdhMl(W4S_|7DfWi2%U`Z zwiZ&M3xa%w5^dqhd@aw(Ji`3@^=z z;Dm}8k;A0W@Oe3j*%t3wvDKJebdnx&f=aG7pg?_KRd`EzbCZpIvzA*T@3T#xtE`Q$oO8pgPlTcCYjqvwfdB9%2zI2D61C3f?! z;kdb5`!{5Qt6%H05tbRW@KWm;-}gOJjAQ6NyM1n5{sIczK!9SeU)Jv3x1+L@i1dAU ze6!M7WcQMb&m#%}Infn!e{eS3(`D zo{~4a)@MchD*(g=i6A*TrcnY+Vf(a%DeP=FkYqdujnS}-CN*rzoNAGMoy}LiWA0%j z$hXqW$+b9v0|IMhhUE%ko=lPS*rX3~72vPA3E9wAx?ir8%y^v?=Ppd_izl2#eoUR> zwNvYa)=r(}tewVlu|p#(BqXp*bzX)CgiidOd5!+Dg8ZZHA_ao7H*c#*cQmZDfp|M= zg&?ok&`#N9eu-4wKimdW_s+BPufbo3ydB)?h zhZ9GluRuX*6&hvxt0y18u@8#qfpQvBa*Ib zMK|>EHcH9UIqP#(7iOi_dtvz|&1%lnJP9|nHe(A`>+39^U^XxjMD(K{)E%gkqD z%75u2(YQR*pxu5A?YQQEb49Ub{xx`^ej&r$^}F2??N1aO%QD!xP8Kf2(+F2ly3MbQ@&7A#EE*L3G(KB4-maJOPdFmco1*ESn|eK zsBSH&x$q~Ab85f(d@LPn62j4~6mEKdN*p0|^6sf=UoE|O7{}&N`bj(jn%HYA_{2IW zR(;%{Hh@t_FUBTY*m}~s*$e;q6w$!I4n?4>5kQnpo^q8rtO0=!b(X%Nm0_H6u+%Z_?X8ty<5J{3CWv-lJ{S~=Y+^2>+( z^i*97w_#EN0si92XE#`v-Q5^Wi>^$*4UP}ub2xK0p~I8yz_TiFeAX{~Mla@U4w{O% z8uH=#s1B8Z6#kiu54@@7IF#q*7$wxnd0S-F*wFs^D7EGqT8WxIm{<<<*%X&Aqgsj1 z%7eROZ;%A|tp#M)7YTdZWJz=0y}Zg=_W?$Tgh|W#*BiP9%~_a`+2r%>eEHy+A{DY_ z9v-_DY0otGwo;OhWM=p>fVLNbUY}!4m}h!Rkk?!=6T-7(ajiAbQXrqS;ANMqycWXh zp!*zjyMdJ%={zMxv&mOt zpvYZfougCR?LZjM?5HXjbHOb9PB*jbV(Q6UfQrCm)$sW_e@rNlD}irVRsfMv5AFP% zIie7MZ|58xem*_TA(x=Ek9`}Ijf!hKH{Ed==qZ^kneF)UuZ=T`5E$dWreHxz&4kz!++`@0j)$|F{@$`rW>t5YtotCR@EAQb3l$H(Bi{T%G0Kn;NUq<%M{t16r7#ekbBQB)W1U3Tl%m=*Mod!swx zC5;Zs5~!K;QBrZBdRu%VETzaQ!+nOG7R=jy(bANG;Z|#iV=d#PLNau`MILv6{)W3i z#?fG4sn(`A9qMK`pZjceRSKUibwv5~kzLy7k|$3-!NiC6)d*I0+5sta2g=H2PwDBZ zu-%TPt8{D(eND#^ae3~Y#2fNPi?<&PHWM2fY9)GQDUfcTwJH&@4i0YVt35h8dR*~t z9`-}6noQ6_2`Ny7GD+asDjwb|i`Wa3H|COFX=#p`oY}QUGh^>WTts^4c}ftKbeB?B*A&K*baB?ndn%^RPGr3mCLb%3tT91={{?! 
zv|q$^LhY}gK+Ygik~LYk+%@UwRBPmL{V;`Hj#Y0gLf}0_J|V+zlzj$3+3b65T-uG= ziX{D5o?GV)*#V*GPC{6L@}vHD`yJ>Tis;!{*Fn2&Q3~?4%Gwg=uM#nN`3Ci-c#Gcb z0Me|Q0Jky8tN8IMs;;dLoKAp4x^mgS&P)xjGkKzL{aK!$f4a_u>hVF~ym~Z~Lb=_Q zB#|$>chbb&o__9fF!Y6d62!}R^FUET*X@rB#a8DU+#OVD3)krguwnT9uNDsGHS~!OZvAiC2ojv4pBv%Brvt#v~&|{ zhB6$@-$xDm6n7Ir!nRy)>0Q-(>{{kf7Ry{9*+yS#-XZv!Qi~%1=NdZ+2$|&P#Z!wa zYXp;^<23><4CG^&uGEA=_AT9M_U`NRm?yl~u&-FAN0F$)5qn+la&&^oK)1XA@ z;^M_&<(rEo6=i`s&s^rqK(l-eGP00>P29rW@X+w$Eyk}LQE1D^OCD7D#~B6 z6aghpx9*z~JZN{M6Hs-+vh}l}ALr{s&9Aw$A?t#iRrSn*TKg2tK8(~C3XZ;vFbRY* z8lbX9@~GZu2rTRaFoF3>{$A|R<#U(=X3f1HG){ogIMVV`~(j~s7S?Lo>ZszV(s`$NEpo@VD zq@lr*JfFOYX97VLfAntso8Lq}R!5F7Rdi|=T1`Jf3)$kc@pLiz3%kX9?SNoM3!MzH zumV11iuTCHlQ%Ky_lfgdl1E_ds9|3Yq$3NYq4siB1Hj@MHC&!}0;=|&|$WlVG zGrmI?vgZ`|(%y$>rcWx&!b7F8u~{Fz1sO(0rY!r!uXno`7k+lbEtv1PZgT5gjz6hHu%yS=nDgIP5~8UpAVXmzRI(pU=0Oek7{e9it?QCK7xM126-L!1Kw$@R_XOvnotoeWkcWLK;; z&>2IWM>RdUCOn})>QrnhqJZz()qKfS=C<$4r=WX-vN9Q%izqwa5+nSklbGc7O7E4U z=-4y7rgfHqivN5dFzyQXaG$LU5lNy^AWz~76*Z8v@lh`xT2}9yTA9>i-)7buX^^#} z=MG+H4yGd>sX|Gwv3N9LiO}#z_168_4Y^a77DGz>;t)KYoRarBTK(9+Fa-(e1Ba6* zFdmeE?<|GSnExx7GBjG+GoxlseWzB9|Nmp}z2lnN-*nMnS3m^mAW@1SReF`!Xd+#D zlMY5eq)Uk+y$T3OjWiJ{0SO>A(xpZ^gc_>!gc1UTc$UAJJ>%YYX70ZC&YAN$bI)H1 zab=nHUF-Y2?RlTKbox`yr=Y&gR({T)o>u-KRfcs%rV&^FJ2_3omQIGO0%n`Gm33pn zjDndl{MnDAyF+gz7|J}odDY(%D-xO|K131=kBl{d~`+-43%ql&ik&kKov2$)(?HLFyb8-kj} z3D*|QA7m+kJpXtQe!coKwI_ZAmj8zRn_f@ne6_;y>*Yf6@h+D&oJ2(p!uG?1H!jOv z(!I>JX{EP1!)+_IZ36|2mB< zexfN5ghMqSo0{PWaW>h+bggPT+(uhSz1MzJQCV9wgQc4NY|@GGBh@XcmF0Zj$4 zjsk*{u*+jUxdAhw{;UMd(-%-#Q)uShIgG%{QcaBN_)P}gcvlT7g) zA(i%*CW#-QmY19s76kG6F;$uF?6$bxlNf<-XRq1cRi=J~ZyQqNB^F`{0B?5f3$H3~5I7oH(xYE%2w`o@grnZKRZ3T>45UQc}_?=LefS?%RSg5=(TI-g*1 zK?rkniM=sxMM)$`o&P!uGLHOILE|ytf^bE4oAy5FM6MOi8zziTaXT7^J`itzz-SR% zD3BE#tJRF6_GgiA%S6|;|mG}cTRP@Z%cGuqa zFGR8}#RSvHNxUh3cc7jlCeC;3wjKAgbHD0>G--`jesFMoSCk`u^vP{a;ow9KDM0ef zx_aOa6(Nno0nF;^+I`oe`Kf|7*T?O>cw^CRfv0MTujYT0D44tKb6gnkATGv!?~Zc5 zr&E)xh8-~_E>nXBjjS>%`r!dbCmf% zFri@|EuW;BQCD?MVtSt@SL;l~C7cb|xkmdbA}Hax6>|x03%eCYzD|^XNKC}@cOOBG z@sDU%!>cuWYM)9!0rW(Xx4phvg@2NW$4>3biwaq!V_~5fNT^4qW`C)Rq=ptSo~EF1 z`ewo=zKi~JSIGVQr^4rAbY+^Z(c3PcoK-BXwxmRAoH+=7fevxb?yoeus6X3K7Gs_G zJmb=5XS#MR8n)aZ{a05xWIvg&zIghbBs=VSeMlvw-swFhWqOL(q}qq+r)j6t?(N98 z@upuSbc-n$ulE_jb1D&FBPTLnIZiXlS-8B^{AseI5Mkpp()vzKZG5=Vo7wQul%!X* zMW>&`ZxDN)SV+_i*F&*FcLcAE_*-G8MAU{JLL6hTgW+nzxyt;X`MxB3L@(qrw~)TI zD5heCWJqW5bv~AQwQwEWXq<5CNNvN~dExmH%+1fH7Fw_B&H2w(xxXxBopiRW$Oi?; z{$X^DvN&J28p_?r`RHx~x>+-hk$&CYwV{xAE<#k#zxH`8pX{-4%aV}?)LDR<8^tX|Z9ItraX{%uY|n~v|b=y_Hu{@LfxwZ!KFU`<)qiHUQR zS`p&6$RXGo>jC8P83A7edFI7PnZ?x9t(ed1JOgc4hG_X5$2sLY4>D+1AQw>!xxWay zW$kcaTuK!GVj!#Ip<$?^4CztzkdaD>Px5Mu{^uQNJ@^@MaNpD)=@<^&$OxHJySceW-ep~*Z6Ucc;%b`YXY5; zSw)36#^EdMyqRM^B<6nYawKn@=LmAJ`wH$$Vk+x^;P z9a*w?rce@Ic-{lQ6LBs_k2p?@%G2QXQv{y`Z-weGhqrkvlB3O z%oI;!DmS%6uPiN_9>@Lu$+I1Aab>nimZw=%pI&l2xeUL!R@+xHw-uJ=5FxCwiNN+k zTTL{2yZMXEqSTG=U;c8NUSVYE20I@}{w9dZu%Tmqjq)dmr8)>n>Ex-len&2|$5g2_ z>%nI14O`x&tcO9@Zufk;*Vtw&v~BhR9Ijwm>?3Y8P$^3SI?m5mM=0DY6l%UlPOi1V zDBP5|Yv!wqcBXPE{_)U?yg7T0pcN9SNCyQZgMj)-9OY|w7`ASIaqF>z*d%6B)i+^# zDx*T4Zm0ErUPW;aND*iU{RiFQPsO~y@!3h5R?Hqe#9z9|2d;ezv%;7-{eY=3c3LXF zJd3DjA&#YajiCuAi53u}=3)v6%Tb<=Ty^zz$;Q$cEx9nqr!2-A#oEi=THLt|+I_~| zqI<3!K2M3=U#B^7FW=nv?3LT>@bj=lHuu}kmrWVXP$}3s&2534UhKA-OtrX_tr~1! z`D|=4@zJ!%_f`97WrDD zoG?rZl~xtnv{3_}^pETv8Z8-(B(bD=j^I;T*`8*~w;fKHuk_vFa8An4EQPaoXV}Z8 z4KKyiO1J>D2wu#`zpm2BK2pG(v$__eB0a9?23|T`y~EMiJHJ?TeuL>V!e?VR1vINr z!tTqqPs4j0oDM8#Lg^g>z$qG3zA~iM2wdr{2@;StGt~d! 
z2q>CaG7=M8m+yaB6PWN-8dZWUUcxKvHh*Z+2UGMw&y^NtYP4R|8K*8#?7`%xLu1e99 zbF?a)ZbO1m}1 z6B@VQY;nq5MH`vkh9PHvFrDeM)p#IuW@<|~=S}_l2Nh~|F7ezYsdwV-Y?f?ei#os; zJVeg=XzAg@n%5xfTJ(+v+cOirr(#jA3FB07hep9(!no6Hbf`bVK9RogGZMrx=0?y*^)c;NH4 z)<0<%-i&6q@F~!!|L(8u%#k`bL-nodMMH-1cCtoI#z#h-<~< zoev%CiRz2rSZwvgJFe6KZ2~?yy4}?-d)r%Jk|NHsMg0m*EBvAa(O-IwqIkvN4JCLI z=2$Ydpkrn_#fg7ja#FP|?_$66&fB0=#ACqovzpE1A5oWQ9Ini?UJ;b2JLN`OTOZGH zJ5>3U3abiRx4;{(_UI;~mj#r~mNb%ll^;p&_{+NHer@MsljnFuoepiEM=Ka;)2Rw7 z8fM)c_$DS6AJ6*oc`YmqxfPWqiZ{rfH;OI5{sdi|0VdRAZ1d829jziwi&AI%I`f?J z9N#`*(cwIOA{DHl(n?8z!+;rl1XPnPiOT*u*jrxPm;qzfM`aCz+4tkG3afMLwaK{h zT?y|O)vlXROvPbVg-*6l%iZVKe`fx-*5<_CeM|ch$Abt`Zh5x*P9l%+fpL zfDsPV*m2hc)2VyfmDNElT9rCHi+ddjG)WieD84REMdm%B{MPldY`0t1iyvbYB0%9Xrii7CRZlHtk?q+#;&Ji<;+CwqLEt%p^TomJB*_8@?cm-WX3_i+_k zsF?XYO(FiCe}CDlW#t5hMDPs{Wo5gwAF1+UDXu)(qo>$0Y-u>z`r>SdB?Aq;Pr`hw z;bBdPodg)?+_4F1+xkew75OzX9 z%~UhQE~zp`Psgd}nPkHk-%v_telR!Vu_A?+0c^rADLQZ+Y(6u6` zoCMr(46OANQVCpL*~q9tSZpInoN|16=MHF52OjWeS zx-NP_j2byco!EJgXAjD_Pj*PEH)v_3fy3lV7BJzs>KF+MEBEJ1z~qIA-o)D#s(}tZ zc{^6qtTZFL0sJExcRmb@AeB=q1Lq{vf=dZ{`gl*8)%ekHB8FFkhP-3F%i_ds#!wgGh zb{JIPDFDBmmUg9F_10TeWW_IzrHvXkSn*80m{4X^hj!wmqw;SP`kpbgn8p+lJl@k3 z{9EMKf9tdVfJHv~|39u;sbMSm49+3OraG?ekRcf>e!SPR}k9Geb(~)iXwa3 zhW_;csb)&4q(_VQ-sw<+G;hnc8PpRsaMju)6Rj!aO9R4^hHoB@iWz+!K(@pSh~721 zGmh?&ZCV(8-+!C_!yxh~Y9io)pHW^t(wLDYmA;d%JMJb|d=d@P1r$B&@c!c#GJxw* zR@buL=b%kzDOaRPeZOFgt5kbjPk0#**yi)5*_C4DXBv+Gq)2yKw|3BDZDoBE;cWdS zUF$9S={9HG9dTt`Di=3htWEs)&$Ao$7SVU=lt5O*uoezByh{cK5{WF*uy1-NTV4N3 zm~#^42}~HX^^*DaX_^J+5bV$7ECdhsFuldf7P`Y@s!9iz#K&wu1=Ycj(ruzxp(?z|8Z$|ra&&B=vZ-6xsHx4@nI z^Ud27zU;|f3Keu}+yLA37fq9*`X>N)RZDgw3OHudgx})(+s{V=-+HcN0 zc=Kc{j3kMjIlmH-Lo`*d?B6Ahr#i1=RXM4a1t+*bc5&Q*LJ`y@612014xXX(q;{@X zgJG#T=k4P1qB30S{8Xs?*pG~Hc7{aUEUX#=DEljTFA6rtx%?eseK z1G#1PawP>Vs&Cw}xAl`k_21&|oDHH@1JaCDDL7ZGRK&-L)kBrn$aMKS=^dek$S(Su zko{~5H#Yaq+tgmH`(X-@xe_T2w$?+tQba-59Y4Qsx^s*I-*~vG@?H-{0_0^`7?lrK z)AWT!g?%M_gE8RTBYpg}<6nJSsNK$69D@D}kYeHVkekO}HgBSTaGjDX_ zEA#h*$h{#pLsO%Woz^nR_O_je589H&h?CmlbWAd*5{HI?`zGfMR{b(8RM{QH$20Ua zHR=q0I|anG7L0fJhO_NXPLwV(bX#TtWB5X_3v24c zb=BciBq04v!x2XCB~my`cmuL5R3uUQ@|(7v{l!U1Zx=4zmp#Q1T8Fqy6lfw#2C=y^ z{YW zq2Bf1B~1~R8=x4B$(TutjSZlsW17=Juk-Q#{W~H0Ixp!dUY-T%1N{KYYr;O%e=&kf=Z|iHzfo=*!RRIWC zg*LUM!}#1&4bxY9xaN_3Q?Of9R;=jMw>SH4I~-y@3_m2aPjh`(S=MH)sSKYHRHkq% z_wT=a(}vmKWz6V=q0F~4xnKHsqZ*XaMs?*u&|`%MRiiX?$nSz*&);`ZUbskq2BW-A zwOhi<2l{0A!UYcFDXtzVH*0v?7*MAgxy0i2**8An;wWalF7zUEY%uUVoVYO?y2K(> zYGZ7b^gvFAs<`+jutCqwzBW^QjYDAJA(cua0J><1sH0+xxw0^Q%y}{3P^0lj^akXs zU(SA!oCcJGXCnNsU`1Fk9X4|ttZ3t)vpbyX-{{=fGo?S^yl$>{2iW+`gLwIY&Z)&6|yQX4-wiZ=?8WLFp}wN z^QVtrFMz(De92#xn7bl(Ci6^Ow(*?;BU9?~UqVSOtce+%S1;spXPy27^ZyVc{D1Ju z9C!`a5lS*Gkcyit;U6?Ev%l!PbyYpQT4c&O;S9@_STX+UP=6LI`1FLj1NI6dS~(P3 z-)S&#>UFQbjej~wcxFRNtHA^Lk4BO|I<5Zk`UoL01<|>ADc z?#+QNU>0EvKmXhY8c}?9jgX% zZ~MzSGKj1TdaB>pozld}cdNe@?f!-5_pO48i_bk)(N-1+{~kz3x{0gGXDr9gTyRSB z-1}k|>LbLdcFlmK@byH#hKh;r>!e%5$6Z7yt}BjuwNP)2`aR* zvdT9yb$Vfr^zt%r%e#5!f}ooc{SyXSVAttugFufdUTLq$DaOAaQ_26~b<=4?hsibw z!xPFQChC3d(&^KwQ5>?|iJiZ;%Wp}-6A3G|FP_jnS%K2csMFovr=Wk(uiVqg;_t#& zynZ+FrBVFQDD5097;?Ve`4h~Uh;A~uj{0~2rNMYoOJP=FLg)FjY|AU0Psg5lCKdqg z=X*~hjhksg2P;w!>zU)iUhN9e=tdaRt#&-Dtjg4HXFksvV{}4>OChph4iW;iz}8y^ zN?HN~ZPiCa?uPe_dOhpqQmzlS_vCrQ;gu0RJAc*CQwe0Rmg6{p!5DR0xe!LEFHhXJ zzWYW2yqYG-bxsd_!_0Xj(yIipXnRpBFBYR&V=|(nab~R*_j}pS-qmT+)gQV!A4I_( z+5cO|2+(nJ!_=uZAhUOBF~3;b%(AxAjgzye=SLJPGWv{Ju)emYs`lC20^rt0YFtWm z3APe%ALCH0Un!htW82g zI51?3p5j4(v}SExC91YE5}Ho8G;OHlbWKdSL$sVlx9gc(bSw=6M>cT16()*BHmBE& zG^?fB9h2`-{kTwEzIoQ>W$PK{tC!J$3ICTq#@}n6f9E~JS19U2ATtrJ@9f?hNH^fA 
zp{$X;Df~1$&G-b>RWA>o=2>Uz%rB>G7-sr!Q{PvOw~n~W31YbYi>IXb@z3l0kKU|b zn-%j2kYF)sbD->&Sbvk#{s9`n>Xdiw_UVQA@u}JhQJVxcz&AMyxA5m7ntp^j+K zF()S0qAp%&q2OaD6RU4Y?>_vq|I8wZ(I{}*XR11UODI?Kb^I(SE_NeQ;UZWzn(R;upKh z&&K=A?1UOGciNWY`wQBDC)%*lm+uOtH}rnR=$ph~`}Y31+v{hH_;_8kSY$r7$K?qZ zV-Cl@N~*jxm7ICo==;j(OL%9@#7Vj)_vi$-w7)#y7xFzntc(thPg!8v+@#<)vc?yb zyn^gwuwu%pJnS7@*>+8so(AUEK?|l^9NsRYB?dhWqyU8g31f0bOO`Hn#Y&TKo82k{Z;!c^T}v^a{;vhZ?Cw&;)S%Lh!jE7s@^r!{Uflk zjs-1FYuPpK~cNH!_7lxV#FtSU#Td zQzT$UXEfrxa~d^>4o5n$<~Dp|`s@r*83wT9kP2Hf$JM>#i+8t=(bvQ?-yOBEpeqCH zjO-7t3sC6mw6!E_Sd40E7%Cr$n<1oI=ACN4IJqmyKOL#tQ~#o)#IVTmM#n+}-`$Ft zdELmk7RZHLM2}0dX;t{@uB%BMq=+id2u5~2FIM5&L!co>%gd~E+5a4HdDB13uU*Eh@pG<;m~+vO6tCLKv17X&)O8g0gQ7I-J0P)@ z-~}hy$NlA0F!HsF0HqOfZ68>jcm+#h1d(O|J_RPL{#)Nh-BUb6OgF{HHBVh`4dkjh zvdHx&z@^M1IJShBj7Aq@^pzQI&jwb`x`9tvZ`V@NhQBp&!ASd!Hx_!`DvpYNRukf& z9BrFsI4r#4;aal>a6bHX_FgFk`jAWp_TnN@NME9jd(o|!V8;lGF}qknW6BST1`DTC zKcD|5)1eJ2eNAT0B^QX03Sj#h0n3MluW}!}#ebP`xY4E|tfsC~RS0Jzf41QtDrCQY zIqPO(l>e$+$*$v3Z6s+Q4_;5q^U=FskH1(qpY3d8!;+vP=k4;6f{yZV&=TkM3+wj}r7M_baz* zAZuqF(Cg&Oca@6G-gla3_$J@Kh8deFOUHZ zibkNYE=8ZR=}Qys+A%HL9#B+AEiGUrp$!K8h2W6#Qc5ao&9?S>--5tyvNY}%U<&eviAE@xI z_lnud-^z#IMzoIc`@jqP@;{ji7;gdz6_=vIy(D@aOBugxC5D9=_IiyHqCD26dJb~Y z>^M;ec_f`A+?wc`>Q3*x{jIowzx6r9WbNzY#c+fc4fVo+A!cJjFP$Ya=7M<1t(Xaxa(!e5frj)7}l zh#N5x*oKl>ZRWj?so)2^PCr5IK*{VcC!D*qWVy2q_D3CD`sT_jU_r%KeBt*`PdoSAE`*p#2W%lj2Z*?0= zr!l)@#}{)NX@N`O1GS=l3@<65L< zmL(^?LY0#qKS3)Cuy4Om22WJaUjtu20ZUIbkisBSM(DWmmYpJH-5m(|t^)2ayT(Av zGU(f?i0`X-nK&+^3uVFl?(J@RQI(~c8@UrUBi|~#qYGpouqQ-km^7@99PUj5#fpQS zi9ZQ1{IiU~f6Zrz#K7x;LzcMGde=f9X)-N*gJ%6Fh|SeP!ZoJ zzF|?SlVDxGXU@(oo0qiVf$aB$)+$U6O1I1we)^L7yshr(blt92`CbwO7JP~<1#6xo zj{{a$nxH~Z0p2XmZ*VwZ6wrqb98-;!jqVfda5hEYs7x|ou7RUa1QZ6cIKNJ|94Da; z9j5KXJD2NiLb})ZC3CN18Y8@+Mdp8Ne*?GVUw8QBw!d!izv2UQSA5nr zq7ubS66#k_nUjZCu1dq4cI?LDH@7$KVgikZA6Y%Rcy`U=EG_Q3h`{@D<_M(^8@b*M z>8mL=m~US?DBl@)IJT?h2k365C?913+7gCsV57f}06!quNXY8*pP(*pfCw@=4g2(t zRiaRfEr#$s(7B_+~5F`SUW>eYu}zl;i3?BG52ea;Kb1h=e-& z(psY&?G3cX8$5?QY`~CUp$(XwyrrMTQ6jm_ z-k*$Kwb(7_%W-gWg+mG% zv2-*&<%4kN?yz?BgB9INPos-Qy7YV14MM$N2p7bsV{h&?#8tyZ1n1^I874F&*Ufw$ z+We6di9?$mNRt`hF@UdF2*cXI?16cM@_>hK;Q`sbxD6TJ+JTJtqr7JdB=?e@!EqG| zfNeaFA_>4ZM2H_z%cGs?WS;|IdXJPZY?bWVurV6Bxi;B{A4?;*tleB(47Rz#TCcyp zH+?)3m%U)e{O}god7ZbnS622(X5NP9hmrL?z^v4ZkIHC8z;VH_X4*nna|d8Bz4swc z0F~&Upj(N671lXI@1+NB&9_4r(8OlQ&N0H$nDh{*1z_u`Cl=I4t^^n%Eb$2N#(*H{ zVZw25XU2NRG;?1-5pev!T2#qH>2ubfiB@^#KA5-pJ-e?HG>oBn7Ih_uTv!=6bxang z2z79Qfz-5`v|oc{!XTmtGJg)3@kwX^MDFi|{cr+$j86bu z+pGS|L+-AJOgcY7Uu^(p_gRt^_<%f$C?>N2cMw>@bIQQ2yf9BznwbemIrb(IVB*KH zEyxHk6C{iok@#n%8i7%fF8_*cn^;|62TuHQdzS%uqHe$*Rw$s{-RIfx`>}5DCkGZH zZFV~?N>-wk{N1Te($^CzMvY3mO!pl6aqVnuRQI!Z&uhMQuGiJvH*w-Eu#BJTS7JEm z+;;$0)DON)e#JnTs;bDhc;cNVSu*5qv9fk~@_tEjxK>^FOl(L{!dWx-@6mwtI}`lq ziP{-LT#s(u8*GZ4vRFa+l^8!%(@Zc10VGFgiH)X3JD-f98csqDJtgU(d=6sgD~6lg zl7Eb&&t45oYWwo!a&}muZhis#Eny=-o@qkMlBcAWl!`MGw?IM-gRlA24$H!wg^@ci zU@cjx^CA>JhI6N*1zvAe`ws0W9 z@MvFL*fWzKtyj@sTACZWuWf8Gef_CK4t1k_&mKQBo!orZ%lSF2g9bEURc}Y`i$q(V zue6Nh;KBu0RWGZ5?l;*qje3{PbiwSa&87L=gfQxUw=7KxT$E3J+=Q--G&0L1WO7+> zY(O;mR_7*ki>OY0mtK_N#!&rYSWz!{u|f^HR9UT+rYI6oG?Mh>Qic$N30z<+e>zJ` zi)gbY$CrbWRNVcyxFr8opZzaKivR2%AokGfdRX-#Qnn_WuUmP83ECi4?VYpA(s-H1 zbnM7?d%I6qF5g#R9NKUsq`4W%hoIVP{_;10s;>0%oI?-{Yfa?R#ko%WQyK?8a}el) z=C@7Rba}XJuZU}R-!LDyeP1q-srIEK)x2b&vBdS$l%v}cZF#iT(-!{~nf=Zhp}U z4Fv30i05pKbQ8+?S7L4l1zS8YC5>lU*3{r*h({eV=t^%-;-&+if`*}{i_Vgls}~~I z@5%lPaG@POtz}T3Ggm+$|FhSy2+!vWL>02DI|-Paz}dKocrDBcY)5CctJD1lk%rIN z#g3%(Gz&WbtQvv=a&K7%MYe*h?Lzc6DV-f2)DLY*Kutd`H4KjMeu*#@x&-%5 
zL0zFd&3*C9ayCmdKILV-0RBY#74f#Jm;mUhtS9ZfX$&*bXE=E^T2HuT%=2E+@P+?P zu=88=-;}*a841h={De`HX%T(fsrsyB2k`6+fOFX6 zXpBQ}32Xa^ufj+)TA~}pEO@H)NDbt`+;`^1g{hhmkD^#+p}B2{v+C^hfbSC~_lyb~)pnzO-Wo*x+K(^uD-now#fpm-M#`f7MQs*n z^Q|ezLk?w~;Ff;EbhH2DU}AD1Msw7}*#)f)+=P}}?!m-V!0iB}t65V>g0KT7m^%Tn zH?uRYfZF>tebTskldNFh_%ATCw+C-Up8}kfHk@+)>`c{9ki$=qqb~%tO#))qPLavU z)V`hoZ1ghb*LQx;mRROK3l$IvWmRX2DSTjtR+Z%*POep2tDVVFyc~fW8=vFEjXm(R zab8IDt{XHKk|VW{Z^i-J846(7AGq}U`BZ+7g$@xL>?!2Xq@N(%4FJl{_64X&0rS#5 z!K5<~(wje^@eqxLldger!A>Y#DP+O?cz+x=3PU^WKt{oDR!Idc>(8_>&`2)x;T|k>_M-nC`yT-p*hq_-)zh%b6;u4dAdwzDs%rrDxj`>9r zzlmGRz+&8vIdX5?WCGC%Kiwj^rd;WBmno`5FO@K`rX^iKl3mFX!!d;fUNvg@;8WF) zD0!UNYt*`Pw+{cZ_5RA2mzv@kw)2I#O^qjTh?XN3B>KFGWK8k4ZwX75Q%c}Xqx0s3 z1W=!72GuUZn2sgIakk~yu_Z=PYq73+s+<7V&D{t6imO!i$3;b8SX25WVCQzvA)XW4aOwKQ>^aCeb!ebRoe)Fm z;+L9;@@M%yXGAZ`$?`p$Vq#gC1)r;q`gUB&_1aHu@uPX)^WGRP7MsberA8p*QZ{t$ z9<6SsdY{pY`hK#hL~+T<1 z@c+fH7x{h{vHYY$zK1C|E(7cy{{{&5bKt*mT-@s#kp*n)5YHpsRSU5rCPEkimUt8= z-Zitd(#bS~A-0o_2QzG@Q4bsux!7U?^My(t-4tZ>?CkJ?hgV}&K*?biBPp0a?50qtXP4f#bfQs$g>($MMJU1lWPkkxD8dPMnZASyg z6u`<+z%htp1R)MM29iR6iI7t}>%3DpqP@`^TS``YB&N{-)G*2um{xx<1NArELSW0Q zhw!jR1jU(sIJ&ED20P&lxx8bFg~y<<*5sjD*XYK!WB76Qa^QT=j+d{Ad)q$YgL?(U z!$xKwF_6YrLQJ_8s#qecm%l!Lj*lnMnPH=`Wu2D$h`y)3=NM6asS(SfZp zpdzyqR{IQoSpf(QDVdsBQ+|bpz{!7AM7n{q>v5yppF7Bl?uUv4q2ZsnTA0E0`~;nX{V)aM*KQQ~mL;?$L&xy+ zoq~OM8ojH0IgNaW zEm-g$k>XKpqXV(>ND}b3FZMB?Q)}Vr!`W}74M}fT%_o;i18WoaVOH3+EKepbX zLs~o%75YTG-Eows(L2@o?d+?s-TuIx`7NUOQ{p2)>-r&LII!W$R_{7b?@Dh2sssrh zsjw^#uvvYM86kA1j7(S_P6F3W`!8tzdsy)T)oeLlp+dBo0}~$zd-_U75pz9W?{iMe zkeVvfT@ENkb#E_&F1NeAKCzwAn0(#CEo(rpfRAT=@rs?Z%kB?BTX&#jM_Uv^a0{LX z8pebW`6-3F4S=mpgZuv-8r1)+V*;VTOxbFZ00NVui%%It{dyGPlfGHNqbRMZ+DIO4 zmtNk?nCIBrOZgi>q`0LN5bNCf#ms^(lLcF1a4DMpV7i~YkuYkhzLmW_?s6x*#oEC@ zCT^-rH!+NV+H$&KcW`7Hp!Po{jsxi;?;|!c_Y4MT7y>-EpP-buiUXJfP|I7%NU!$I z8td@KZ@acY*Z``*aiPZyOF+~=l_xq;o992}WgWl6ph-a2PxU=%miu9E=#49IlD-sk}N4v?Pj0P8X|`~R&x zU0pHwvX8@kd~Ad}7EL=5fY7WPUutWZpIWT_hLU(%e6x$@>~3lOnQQ$B%}+@gCf;gN z#akYkiPiS&eNXL98$_>E`h6&#LU>k=6b(Q12l8cUfe@PU) z%LYWfr+{SZ0vafO>H{b@V=DGvyg7D63}fWLQlhO7ysFY&98{|nTG zGshiL7ybkVIl;EA$a#R@`(70`i-D6Orv1f%q^1DklvCrcq$XhYGJhpC0lU|;ecX%w zm+E#gn2P9dALkX}j748G4M@~P&MYc5EKDrPtIkM7_dM^>;@vG-B~5_CJ&T4uEqY&{ z!l>2NP33za#^t%To_R3v-~5nNobcJZGI{g`%J7TfzhnUSJ^j&mC;{3J2tx!Db4CpU zILNm5;rpC(vi@ZZrY^*9kKf-|2m;N^qWf`yR-xGjNKw7TmtK{8JKnmesnH|5ICdYO=W639V-bt>z2PlA+1dXG!P0Ea9hHl^K3i&nr?9~j&I~N#)VdlJ?0-udSEdi{`$4yVn60^ zoKU^$V@%dkaF%l`&`Z!xNbqX>1Oys2|Ctq7zXpb|QJkeY0Vt;(kCKj-14{6&ji}J} z8NIe+y|E+MGz`MO1q*XVlC*kVLj$D%S2F5Icct-{D+v?|G67c-s1rD$0ap@;Vvm9O z55HW=oa#9<0`nnqQ8T3?`s(|n1EgV9mHEcuy9E{n+@n?Hd`(|XO$`jD&QiUx2|%|# zj+MU&ZhtO95W*tr#dCZ*?$_X!!#5g73(VJdtf%pt+f%ldX;J{${Os#P!$eq%I0-PH z*{(qGTNtSC?*$+(x+$44oPEF?p{Ft3m2ti5!T#ICmCBNW)NS9bi-T<=`n-3=KUYpJ zav;=|g*+$!KcQ3u@L&FpQw~vo0xIn9n+CcD`CBQzXRyBjuSEqG4mhG0`-SRdvE8np zokr40AFJ!r{X3l4mf}-mjUyJwNdUWxUcdeOqk{0H`9%NAiH~>D&!1_?$S-y+ z8OhXo)zl&ObqK*~Wa-4xr>B)AI*l2x-CFfgr5n5%FD>oOdzIttNulkT z-_ss{&h_poq_3j%4|ly$(q$>q1yTfZ)pwXlZ(#XeTwyoVbVAtD)ETzszPb9ZnbO}4 zmgFi>!h_lOg0XSoFf3oq@WP15T+9H?WR%bR&Y2PCx&^^DwUZejEtYmF%C6^!MydVn zerlQ-?^7z?dL~uY4Zbgr?;*1b^7|qx+-gQkam_vNkxXQUGT@Yrq0gwbj#Ec;YNOP z`wU|v*km;jGNQ=oFGdV0RLM3g2)8$FxTnzwkMhl}ZBD+cCMX*7vE_8+8w$X!4=0QH zFg2qt*9+mojB1i0bUy9!sUAO)86BsDX!BA$UN1i6&u}Ptkr^oI?@PS$N$u!4iT}Gw zD}sSLOT^r3TcF7w?leHJq|m6GMYiA!kDJ=K)~G2yjk2tVesuCxcPjV+lf<|*9NwDL0fEw!{Arps)_M~ATIpMFr%61Io8A2Y( zAfpl)j=G}Q5+x@>Vb>VGtI*8`kW=X?jwqD;PapC5-zDZ}`fNGiY=_j#apEbT_bHkvXn|EJj3^u3%_FNp~c? 
zMi>)ivkpq8Wo=EngRp5Qq}t2sO{_khF=OQ(rt zU*idsD=>{v`|Q+^;R2$0N@O0raT1?9Ul6oJEH5USHf1s4Z?*DCtY{HZdIh@~8j#yA zj3Yn7)y|2Tx>|J<1$HZP6EkPEFuZeG>kacjHE-HyyQO7>@kL`w;gkCft8Fq-J+W%c z71H8-2k&OH1zxs-X!;)rY)iMy&`l7Kc*Dqgt2x(LL-09&hU`^<^ms8b56XC`zc*G5 zs@xk&$m~gAWYcr9xDX_j4AVHygg8;Vw^YrrkWS;d+Gzdn#Ley%JsEkiSJApv=%hTg zd7`e}nANd4#J*GhT_Ja=>e&?`|Lh|ne^0!3`#$k)qX4oT39N#An>2<8K9-vd5yMQd zvAT*Gi(fW-tvR)=UmREBw$S_8g*}y-EnVFvA z(8iL_KLD7_uoVJhBvTS8TSk(lSz)ctCQB{a39GUejyYlLM;bM1I{I;G!XX6QHV%Xd375iY@`+*w-*-cQ4%j{t5v@N#S71)?2utcf9)J7cKpJD>-BkfC zniVx<9H^i9=tu32jNo8p0@eT*gY+40*(W@6UX*F%+1iEn$Zm!V$V)XI1Gw62eg z_bGgCeq=qnut3l#r8s1KYZVuo&i*i?)mGm>Wu<v-3>2D`8GJy-Ks--)Q~6L+3xrK z?0N}9JocDhgYwRxhnj1tr1aoMb~UuFn#S#!)$LH@cOX`k_GK9s?hJzl5JT8pfR=7Q zAa@C6l}bAt$+mG)AP_`r+xOj-SZrUJH|mSsyuge*---D469j^Ch#VI*a{1r>OpL=% z7!wWW%qdN41|{MsSEM^XFqdGjDhafU=2ot+F+QC8~2_wN3@q{ z$|LA~Ts21e5=+Lv72O)nV~iz?ik-tR_7&4W{3z7AU92f~(JkQ1#FX~H6aDo%^EA}G zSZW9D;7<2^Ext5$!^uYiGWvbDsl?vDe0KY4-CJ z2qGXL(vgzTQIXz5O(;rFL?B4W?|uGj_RRb4J+s&BleN~KHSYl@ImwXZx$keeuFsW( zJlFY>fCm)=eOy@= z-BVppf75ym^6+?h(Jon{E(gBNMByi~(D*$ZsBHls2i%qw#5j^dHN*@LLT?>3KH^n- zeXt#p#?-LnCkNt|XWZw;p-;-_3IM$w(+(4$S&{SFDi_M%nRpKyg{8@~h@p*!J1xEs`+->_ENI? z;D+r$=C`AM%}H@d2#KY4J79T{#xK?j&<9<8An_FrT$in<-|_BLYA9H`vbQS#adiiQ zgkLbL8bHdR;Qiwv0V=o)EkAB`Wv4zxbMUuYJg?|Ce|FOjM!|E}TmX0Vc+O_|Uf(jT zuPtN=yY-D+Wk$B1oXl6G%zsELTyM;-DD!c4uTM4$z3}AjWH@?1iW(s<@+b=x&$i~Sk0y;@4 ztVU^lnxPhTHQOfCZ6#Se80GjH;-c@={m5wW%Al~#1PFyaK8p&bglsb`!T6Cyy*62z zAr4vzDu0kRo!TxbmD?^cBX7+W-+#Ro=YN{}^_@ZV=c@1c>5#tWkj~v1#4;GTLNxQo zy7DhlnD)9$e9TuI9o2eESeT}nPLtWz8gPPujD&J*LS>&?u1cX_roQaR!6z<pxzh7T!;G&vk}Qp8IHA$Pz;o1Qs4GDa?*iWQ*sZqX8{W_}sB5RQMikNgf~n z&$d7OkM~0v`+t+|S z9`Gxo7N7|ML|z2g5aC~j47bSny`Wnh^6Y)GE1Q6XaMdSaNZ4LKECs{6gK>wk z)Bcsh5(Q4tyYtDb*=?m+KUZajr+%a8=13hb;B=b*S&5$#x7E5@VNv#w}cwizJWBC2R_F(G;8V`pHOn+ z5~Nef_88m)KeHpMT^~_srl_doV%(H*$@FBIMUnLd-N(m}XDWcl;X#JHbtxmP;nsumf6NPKq2b`OGYU}-1zN&ol_vV2t zB{r9QNVMJNSRQz2zlG9WR^i#;|9G;NI=P^2Z#x&$eprSBz&6y?V=L-Kl{B;_0E~s5 z8X~#p{Jb&Js{2%FOHqlj=GxD`v_Bc`e%cpJrIa!ZKqt^b9P3FQT9ES5HdKw+Re-`rHOk5BaJXv?}e+Ng@suXQiKYrZ)@t<#2dq_ z=f+2vTUd9S z8{=0la+E$CzwmV^tmS{kg#0gHh5vAK^S|?c|1Q+wziF;I|8OASh z>hR{1BBAQ~Qa--Nm*`X307n6nP@((@Cqlb9;#;91aeH*Z>>p6guI~*q{yFJ~*nUsa zeV#A6@t{{7Khd(Z>%;$m{1Ay)E}SHa`|_qFm~7L?65ne0sjY5u(x2nO`G)wx_@T?? 
zS9%voR%-9Q)5~A5cC3gdE-Wleo@1&o=kyu5bG?+ta)#k3``DR^HK&>!+tQe*Gt2%O zCDC=I`{4@xzla&Gjaj7mA_J0!`_F+n(CHTy2D%pL7s!Ba9cU8jYB9B{YYO5sT3>Fb z;www|w&pljkReAP+J{In#)49xGdjGSot?ko{3XRbrMbM+?aDq&;*k%~fWwZ}QBIV+ ze?Z;fcu78A*n{!z^Q0wjp*IYMF=v^cAuKsvZk<~BL`t*@!l!Dr45qlvCpPj1H^iF8aR2N>* z3w?zg7oF4RFyi5uWHwk_YQFltP0Tf$Gv$xl`C9y=-KYh1sxEo2S6>t;1vrI7rxN!i zY*;7lc6hgy&o|o5BxZqLilIu<6G6W*GxETSN(Ao-5S;6C)YddSHgVR zHCt2&NQmnAu(&Q-GUL(I_@?r_f}-J96W@;aZpS5J#!p1dO2`;e!%GU5&aR@8s;6B7 zxA%NWl{Bq8swxO~$2F#|LPv{I+8^?32VD=_l1pAf*9;*LPos2SXWct}iThoo)Vq$I zi{2pG&yxiFN+Wt4a-iKvPi}rs4FFu^s?-X%PF}K;(?ixe!WJx{OWF8IyX#p`@57mMQeTEgKi_@$| z>C7lCeb$jyqYpp`WAmRZ1vkOcMw^`vrbW`XTwe6BA1`TZkJ#HRLn1PcsfGBP z4y+E=O`hZZ!sqTkYhvDD*topVx=0%8Sd7Dnqb}+RP;||8aLJ9l^8jJfk^c8=hZqVh zMi!>L4NC<5Apjp?4RmD%?}j}V-o@wT1!KtDi_^38=UPB~Vgn!m{HC#=_Agv9FZJHi zycU2&ex`Uk_njB}rjMUnhKq29aV9?WvTlRZ^zXh;I*%gtPf>F}S{y5)rya7l!K4BG zjb$Rf+evcKVtil0Jfw-+4>OcYtgUQ0@{NyM~_i zM!0m)*(aAMLh(UU4LO9tb3WZRBO)UrYe62uv$UJW*4prirwqb+(RUnYX9?YB>k}TP zgk6-s2Za(8|;VvV!I-x@;Soz3W9zwLuACpke{Kc zyl}U7&SOog9q(pB4qU9y1hnxrs^f7Twj*7vQ!8}Uo|E_VF1Tl)OI_i&tf z#$c);rvCXeT8)g>sp(thM~teGFu*0yxMsU9eQZ5?T+9cjBj4%$ET5WcHIo-xtO~t z8!!TL#j*%$IGTj zwp)~$v06_+eAlZ%+Zunp5D0`1Pt>GLtj4d>ch+deM7hSs=J<{P z(E+g_Tg&hM8{jhx*~B59Uvwq%K{gPhbw8r&=47RFTV1Y?+#BEsGsXJExXmSV3!l>A z;d=abX?qEf?MkiPOA~%gt9l8&z6KZhNDw+A+;Z=`x^+Nu6;l}3I2*K7WK=BuCG7Tz^{{$(5n z3CoFu^5A4J-=$gsCz~Ji@6)go$&*s($_V-#X1Wp(?7d&ZVGASmADv`kRz}ehnCV*J z2`yCEOCv&Qg|@vA3zL69=2mz=pLV=A3`ET)rEUEKT84v=`lr8429=-Vs(Wf$tpI(U z618(Zm0`n^aTykiY3%EO$6{(5p9An;>+uSjPa!hEqCWb1HtdUI(z;aLjm8YN8)$RG zCjzb}i-*GgZy8qzKuCfhN)c$C9H|>%35WVvJjthYzrV%)%=;G_3q_u#85dTd*`l=< zR?P^Z%y?@w7xgw^(SfkGG7{muj=eefxIiVf-lbDDp&^0?>0R zFhV#{wrzQ6rhKU?4Rw`BBpvBXd`K%>@ND5sDfJrg1rr`xGbCm3hRC1Znlm}e*rn2q zr^-pTl*~BLu2Ao?B%lI!Yib%3S(U8id5lDvOe3#bpSBc=G7;xf zu)^J@#vOy{KQIC)w=(2EpguIOuesC2jO50Ww-z};^va}+{{72^%DkW>>%bw0w07v{ zsjUt&5ZCIFhT>g17~3C^G*qhW7)PxAo`3aDS0--+XCt+C7l@Td_hGm+`0iI z4_{LfO;`H0z(eghz5yfd*HI8Xx-<28Vuo!lVysG&9@Vi+v=5rEzT3X++;&NRq3=db zp@bitr&caST;E-WJ(nTFdIMwae%98E?raBHPwG2Yp7~t(zWX}^yiR*dj2ZuUL#UxK z!?h_x+UwC*lZzbkSD%C>53~HlM598EO;MUB@iZyYt^o2TLJXNd4B4*Kr0Om<#Q*-m z*z6JbvO>~k<;m?)R7aBr>LI18XR!v6s-!-EI8gtLA4O|hL_8)0;p6@sx9Q)|_v9}g zceEOpzm;JvX)qTosBzqHSzKQiVCFAwTGOAhbB6hG;I|Eb0|E84ZI!KDmF+w%LGlL1 z>o@&pUR`KzM5Bk~#J|AP{}s|8+cg;N3_F4!B{4x0I=)9`=KI_5OyVCIDNNO$ibDgOTX>W;IY zRg+iQ#OZ6}=NINZd@>Vm{5X}(t|cJBaMj-ck^}|qJj$lCIZ^mVV7nDAQ|I156QsZ6 zCM8m*?xuyUV>}m+j1W=KEdgTMssOrc!~l}JSU;FaQemLtm+;%LLN6`Rg_bZP~%q9Qwq#x=iMC0Yk>q+bmDRh8YV1=_U3W z47es_=sl2H+n!xfZwL?W9MY`b0%OgSl=Q3=ZUFh3a4zle$zc4Y_@Wq*UYNE*lju}Zu)y^kJq!IzStH$&LF1alHXgh{J>TpM%j{VJ75G8;^xhU zpWVPjRyM!pv-FKa&5_<39-S?BF6zCsNC?2;>cY?L67{j`_51(;oZ`u)?g`3R&J}o^?9O4D3JQUw+$YGw`19OEhQ=8zN3}-OC`!PbgK_`YXny zE^(XAV;3&Q_*RaFd@)UTe*qdi!|becD(tN#v)1!ooSsqev1MzlJ!g|5xbb$@tHI(*%tw#Zj~h^{zd&@e@s`);ygA z2HQDqXeo}AHePb#;dj}QZq$n^lC0kdIx-@+05eXbV)O#~@`z$!((_vKhm@jC(WeqO zPkz`gctc+gPB9fbb?M(PUfymLw`AGa7B~u`B$-fUNV_R$r7^nvhQ1QD<)fog^DHu? 
zOE4_j5$UajReAl`%`PXDZNb5KKK*+t(^EJrS`T^H0R?=Cy()Z42k$~kY1&g7797U8 zwpKwL$jiPGpTF47p=N`>NSE^odBU2T3cynr6@ELC;9JE+42b@%89>gMLwGj3NSDZ) z@C}FQiBE6#+2o#I?Y?|{RM?qgK{YrS}g{Ln(b=nq^{0~Ye zHMMLUQJv51Yp|CHNV#EK$7K~3YIq-A1P{aqZ!c_`jSrjnp9wJc`|&OsB=c-JJ(Mxi z8N*KddJ?aH1}Sun%+j+f>1A`*uyVcmFyW7#oL2VjG8gt6-ypAkzkc(K@vtJ+aQ+wL zesChK?j#P!_A3CzO2%)qRUwNkF+%@<+`SXrO1*^;N~viG90{B?llREgfU{~%((YyE z_MS<%F7$aYoK+(&0mA)>ehwJf3cjd$&_xkdaIz9(`}8O~JmdGxhIPb*zV$g1?W^rj z&sJX&Skkk(BvuI}M!ZAPy$ZWy9$^y@E>Lpdn{MUsOFHsR1NO)rtGd1@HLS1~zU+~- zkLT?>mZs)!5Z^TeZEc`Wc{TqTd(pKKnV^eJOd+bv!k@xaVjMhr@v{NT@E8=3OrJt@ z(uI-#jIL{D*Urq&_sis*-pgSz)07(wRsj9z$kjChq2Jji(XLaQ7r3cOl2UPHvO>bxPdP%D5)L7&GY9761ifr+DAzT(6LWAF_brT^BB5L$ecfn(Gv+O zbn+z64<2%P=2VssDa!xmn`K}J4Ou#t^M%8rs!MwC1V&O^@A12zUZNtBEW^2a!jByX z##gPboOk@d{;BPrDj?%I+y&F%j((553w~5RZ8e>_=k%ZuHKYfBx*8{x!i@+I7y<%u8f1 zgQ5Kw2x=O1;yu(ACRj|8Fi@=@-Tu)0`fUHvlXJlYHjyhVj2b_HS$z^P(?~)yjG(2c z_8t~;58{Rzh2*7P%aZUq<(F74hVjGmaJ<9(eRvMKHhl;qh(Ms6e8_%ul9MQ%v_n#8 zXbivbzSc3JawYl4@TXPhm2y|+`%Hbc!2IMYTAXUw39rS((3}cpzA=}R`i0%bi^Fph zvTPf(+>$G-zEXN3x6}CwuI@dH1b_Keh?mE)(wNDZzM*)y(xWY31qg zIR*#3-Fv&%YC;Omtw7>n(|w%{^t0pY8dHD16}B{g&MW-lnVRW78uPNY+0T+zvonXk z1PF6RjL@!_&T3%#>kgZR5Cv1nA9gLPhxKJ0HX8`NdHxapp`4@Z4TZ>(fv&F$4}-04 zBRd!w;;RHE>F0qyX^hfB@}1gde2=W{5d}XF3atZ!y_&f`e3%r(I=h=Ey?M~{GU!KI zlvcw%|6WZd(8;^1T#OGU`GG8zjSTL_oS{K{JwnU9z^>Is=nONz&b1wRNLY{b?I$&S zmzLjpNF~yKkKwdvkI3*XEQPV_Z%}WeN01LG90R7>+D<8bQmQ{as$m9-nJ-~qGj-v9 zxA&W`33_odZX0_njnbxS9bZHy^+GBSGUHHa3mZu9ozNAS6ax+DBzC zN!u-GD*SPd*qq$PFB4P_g^p|opbcrX96(xfd}l3FMK zlq;-f@YrZ%tewv6WLCT$)dvDL^7-`yplVMJvIyz@>#;RK*qS|Z{$UT(;{-WAVOe()I5IiDMh&n)5aha>cen2X3ElUpRtK8q z#N0~Z1zdJv zo5br~eTXQXt2Hj1>j*TkNNdWN`^}ylHdQvYDg+1(W{HBC$D!1r8`Y!Ad_$c%nrRhY zUxO45RpFtfbQj(Gf|((lM^CfU8s_{RA;NW-s0Yf_DbhTYlv+euOlmH!nXPP$L#cX{ zYBa_c&ip=#?Y);!@*;2bz3+^8=@L+oC;=KUa@=xX8jAlIOnXE*zvwL#me!RQU_D&t zRUag0*%+-kGRu(}(KpcOlX0{9T=4Vb^2Az-%7GObABl0D$B>HC;)943VPsI%zS3jw zsat=azwDGGiM;dTRN~lmyHgvMEYTBd%Cc!8u*cE~SbuULsrr&MP71v-6s-Yc^%T5J zi1%Ab6v^L%#Yk(Zrphe2df)6&(M+?w?)@6^9Ah`RohK(dF~M8X?{( zE-pcCjX&fCaxd(0l|)4_0ZP_8=xzwR9!CL{k#v}*UVh=-cHy#szd0`i}{|e#>#?1xE?V#fWZ*Qmsz+ z*GLig54~zRg=xlpG}VC^RL~8~o?qXa{ zbT)~Ny;Z6BT6$iY(<1V2N4A5^WE$7p?vu%ckA&9WooVNkKJyfC?lU!4OEa!7#>0gi zI~IA-5R~eu`nw8mx_0vyZch5m8Lxv&M~zyD=Ml4JH*V+}*1A0D(*Xhy@kAIO?Jfo0 zaiXnk+_JP8u8=#4obde)?^`okZF*$4yz6ERu0(=6?8Zoo+YA)L<%r?Ex)jbd#$Ois zoSI5&8&Au`d6mYPY#)4l{jCQD#6BN{QOyhKbMP~Pm0D&^sk0rL`I5NNd-V#3m+o7D z(r0FNL`s;x%&G=SfEa-@*P$X0<4{Xph$%zHS#tUYws?K|z^X>S&+60=Wqg~VeE^L> zL8$gb?Ckly-bE38%{lm{;E2!RUz51HX~d@r-?|nKW;XHotWjf&?&n$QSnado+|g2< zI;W6_L`)xaW|uUuy_-~t=31OHFM`iG%^3wx?z*XG`S@$JroYJ%n5z)ad>q*y2^?8` z)Q1$IF5nucF&5HH$TfpGx7)y2KV?gkOBOxoARA|H>p!wB->EKiwoUTSPD#F#T$&1L zn|aVw11@OSrCh8qH+QS1t+iIs?NOFVsv^td_)Lb~x113}224*N{6OZeEu*dqJqmR>`pjQVdX8cUFdX zLj;195H?Lo-J(68-X?sUckKDV`g7WcS`c>c69xsZ8;b{|7Deym7|rFxzI4M}I%WP^Dl>}A z5X9;KmNJ0WM7H!kP^Jjc_+KuoJV*I{H2dU@;chiv5TXyi&dS=6fLhmhuB0dEzE;c) z$S}Q=aocP^jZj4jGpBG65$$O_`oI^@YH^Spga^Vz!Z_a4b5JdXeJdLs==;uU?=JgF zNr^`{L_+V%SAECK48g0(CEL0#M3pEVp(J%|Oup=>GSd$B(~S28pOTG+cg1E)S%1D< zQ}pIZp(`h9jcO>kNaP_q`Co{gB>qE!@vx?UI$hFgi}n? 
zsE(pX;k>CcrQ9jQt9(=d2aL)HQxT98*h?eY#aMX_yzcjClJg9IMe?_{}lSwu}SnZlW3() zJk!?b3Njk-4*X5&SNB%>A#~&P^l^gH(Ami26{kRbQu7fddQCVBL)iHX?A=3fuVN^f+1TJ-FGb|)lF@+z?0f>M&YlOjr;HlW-= z19z{Fe3%%|yS~6VYw}q6-XTPD{*&3o?id39&L=<)`#q7CMa@3GfP}Y0Yy9vOtG;d> zGBsIGXXgOVE-u`t6_4Dit1AT0jA{9o#bYex++k9-viTB@&a#w{E z&P9CGNJ*K4H$O2r{Xq6AE+b>^o9|Y#^rS||D5f4teF&cNCrcZTw z?`3E=K??|aDlS6HtQp&Sxaa4$$a>t65g#8M=iW<9ML#-8`?FaUyC{m&p%48HE)Izx z!kC>rvmNxb0@}J9WXrwx`q}lKH?&`mygV+r(Vf4zlty!c6NE^JK6`w4F*PE)Lx(X~ z)lMs~>{eTR#(GqFDo|S#fNM*u1{UM6ye3;iW#|z}5Ji1kDLeWWPHosGL}_iE@YX1% z?#A1##8n$!`q-2uXu^7 zF+Rz4xzF9wJk_yqdK{4GH6Bz@@0OtW$h*lXFmX`~<24QFe-qB(qQ)#Q&4#`h0^u@D z)Qa*3jaioqf?0l5#hy-^y?gNH>{5zC7xb@R0k|3`vbmr|(Tb|u_91D-0?~nn{%a>C z=8j)?dvM?EUH$=?#-EP(B|u$o@mdl64JQ-6T=Xi5Oi0?sM`m%2yA8MMQ z+KU98NCFcbx(YR*>kkKeu3WK)y! zJC`fmfq+k-G$>CI(*af^`iwL)EGf83PRG6WaVOVxCzYYQN*`~zrKkMxM_6>112&C& zPYNTU#-3H*pwVc^3HCg-erbZLLaO4Sou~A?BG#vIQF9db^xJ&X7OzQ%ykx!A%~((J zAfO*gUQho!^M8ee^WT+g{x{9ws!9vmjb0QAgb+SLQ=y!zBYuGpzV!wjSiEn&nKIvo ze5D_w-U0fiTJ6!$UMtD<2S7EX|r-KS`N7n3AFF~UqA~eqcg4vD@^>E^~5e7vv}wZW$3S7iV0&Hspqz4vS>>dM(+mgTq5HJIDQf?4 zC{i7koCA%{ac-Diu!bb-x9 zew66TUm+N7noFto-$5+&%@Ykf<$Vn!I~kbD_untQ`cDlkzTJj50J+33aeU|@QG6`M z)eh>uD7JA)Y3r|kNwL%Xdd2S>(Uzfu)Z{XL&52h}zB8OnHPP71{JFJz?7fLO_46bb zL9^iXXDwb0?YH9k7`2$DpP8%$<2UYW-0b(^0D^nq(6vWf<`}(1C}2uK(nm; zF33YRY1!whZ^@cE-XmU1N zILov?Ol|l^jPUD7kS;@ucKg~FVBY{}hiy9`XQ(ysneWG4Vc6%u*I{0sr_${k!m*wc zGpKjam@xtW@$4(-ord)G2V09lz|UI^C3uqmo(zuEvoE49+UWpAC;I%{12{6tG6B`J z($r=sCetiwCI4bO84?i5dha(lx(4NY@)E~8Llq)<#7I+i-V_cZ_mbu&AH*M^Dv-Se*w^%-K1!>L*pUBbVg(xv9S*( zF#7KmG1T4od11R_3VXN#WH?b^P97CmpN0pPykIXZV9^GMftgOa@YwI zN5*y7T#~PsT^ojr)Y~h4NSWGJqzE+J2dlQ0PS<_>c&gW|Y79uRUBgPCN3^iuR#*`6 z3>CIq#h=9V9OLQPd`ruQKQvtt+nl-M;y%;j;>Ik@J9zm)I8fjUA(IYDs7}j6&qX1? z{D+^)e`lgD)N0|dCyV2u^Kadc&-a|ORu2POPF-%7#Z3eMk_P4H*6cEuM+cl;1993X zD@U$-@p;x~qmqSNH5MiJ-6xc)ls>Qv!p=X5L~yin60f-=I+%1+r7oU9%Oh?3fG5m% zEyk@F$!yy?8i?R-`%yhqc4@jX#A^P$xUrgxkvRW1$R$KQaL%RC908YkNhg%A_1ZEt ziY_y%X96tD_*o_XzyyZE_;2M)%R6+MYQ-d84{3Fk-_l{c-`KS~W8Q}rqbe-#vO6J( ztyh<;;vHK5(g@f%)U2bE=7FsVuS0?7{JSGfTyVCrR1`y1vu6Q#a95DVeo~Isfv_PB zms5*~91t$5%e~r(Pa0gj2-kn{=m$Ia29qSdyiny+dmH^d>dW1~QvW$pCkDz+8#$b` z6McNogwjWR+pIJgO>)xgGox@dl*x2D83|Ojd*ly&Gq%AEKk?l#C8mp>)12@~xou2hf$GqF zftK?Pxp6DH;+Lc%r7v2A4w>~*otfFpf_P!?O;)Ck*FEE<_({Jpu$Q(8XjE*8SAl=t z>tGt^bC@@-HB@^Q_(a6s&bgeiv6A&ImnXhDbxl0gXH!(3agTZLAmpSz2>&dAKF~;)CA(vZ;S!4L zS%l11+p5wBvyKIhp0YNuivrFcUz#TF2<^mbD{axw0bZwIDxokClopE@)E7j~9wbngiDGs^vjB$I>#q(nxbs`y zuHf&^9|XTV=^H85H;KJ7Vn7mTN#b!7VAfEHT;z17YLlZo7rCV#O8+&=;rMcGC@=hi z+Pryw32JbA$mM|sv|~33Cvfs{UnO}F>^9=pssdC1nnUpDI~hb3=VO*{YYwb0iuidz zEY4aUSH!2Ltu0wP8?a0OCZH-x{{qDoUz>RtyANff_E~1z>cHa(J3?a`9B)In5i%d~ zZx~+)7x_v>w#~Fbmm0VB??Em_*}EUwhqV>y&mEm@^|h}~>?{<2ZdRj3i-v}1onsUD|Cj#(A)i%(#R z?V;xfO+T~KfiAyvFxxKS7>w^DG%M<6y^z{_^kafVbC2W!WBBJ*>5D%VpFC06%#mj# zr5LtD8-yZ(RbLj>mIMP+`?@4$+OqymF81tsGjE!P()U|cz^|q2ZlTgM-UotEJ*}Je zCyA|2EL(R==>W}O>aky|*~bTcr4}K9()DAL?ro`S)(MYNZA>TbdV}8Po8DhSrvmq* zIyHM~Dq$Z>0HP_xmHaBEu9FXOcz;Pk-{{bD; z4%77vgWT1*E_aUodTKBvH_P>-QR$M~HId}L`bg=oV$iNspY<@{2FXnW?#1kbmImtI zMjMx|Vd^DnVEY4DLqx@5{UkXbp1u0m`SP0IeD7QUoP++=248OtVQV!~=`bJSa9X#e zOHJN+G?ld@^5^*3C7TSb23z}OL|c7lreDmn9yHz-49r!DfEiWr*GptTyMyy+-DXnR z?gN<&`)f^2zN5Hxr`a{Vlq*?DWrWOLQ+z_5`a3{p7GAsY-GC?j4p|j&_YD zEZ@mG;3!WJ>o;?ax9nS3vr+Y#l*SlfKQFUA zu(cGoCmP-%R@bilwJ`ak=I-m2fM>w`E;QmF&<#0K2%Qyh`~a6pH*l^PBI4*5sl}8j zB5Yg2$CWNQQYRs0-Pkg^ALA48KwBj}b7SB=Ah< zVOMinEG-26w{0PfB-;|VrSpD??Po*Q1U&4ffrl|iGSh=MQ2!AFh%`#=B)9b%5gWIl z*GREV`ULKVCe#dsmTxxSy;@!O#r1-pqQS*YfzM~pKH>;?d+H1lld2D*Ph*7i#X86= 
zB)>Ky2bTS^gGne!c=W*6o3t=^BhdeJ`ASmBv|it0bkowxTl!1jjV_`rsJ@gv3ra%* zg=0xBaoZ{IMBU34CqX%LXi@CcoASl)MMqO#$cXvYQKZO~IWeszNc(PrqzJWm8IuZd zVGwocq$<%4FVDae26xejwzixY8!tHy>Do*8`Y#=Yf?gQFt2_%%{bw*xaaX5sKszvu z@h+ubpu)c|NZ{|q;`_M|{$4}RHH#@$c!~nkyCSS72 zK_M`I)#@WmwWq*4r*LA_lrmiP6$;-1Kre26jsDZC6g%TWC)IK%UMf)a#0k;+oBJZ^ ztU%WhK-ue_6wOLmj(}NDx=)@}Vtj9FS@I_6Nyd2-&iE0whH^dcr&>JPC#OnFqBu&W z>zh;DW8LI`EveOf?yzBrVqJyi8HICph z5mfF{uXOeX$}vIKQqWA`?RLL1iDPO}A2gz`=1J$KDik9_|Fm4?Fdy;trQJc4N-rYx zosuuG6CR5(D!ywSx-3zhZ)7|256A^J6McJNq1slVEXdxlck?t|O6t4vO&t{7t$o6|Wqd-wIa3d> z@ZF+BkwXg|P>-Gk9+2BqvtvDPy2weo^GPlWQQF>E`|EFhW=P0~+HZVZ*M91f)ZX8k zwYk$Je>zHAd}`1ZTZ`W=ucUEc4}9?yt!Q-sfj_SXceq zcD2gZrc%NiMq%JcIu{ZP6hC$N9&mDV%*5Lqm6VD<;o5{iqZzj9Luw zX)v+Wu^k<^Wa;|b*9A|4aOn9S-vG3VDHS@~0Bv6VL_$yZq3hsbQs_f>U#;62=Z!zg zJXwC6^2+M=SWg5CS&WzrVbSM7Iug4dOJhxIuZ75K?^U_UI+g{D@4lJ;vgI=c6r}M$ z7qUB}bG$s&4P`{u-W)F&t56hqtOq+?u|dd>&$e8stl2W=yfM7Od{O`L&tyq}UqzK# zMkKzVs`S7)_I<9Z0ONUZklRGv%J{@0?2GR8*30qFf=)BUTAGS5GBv79oB`^+ys36P z0Z#(=B4Q!JIQdPqKx0yS;}kbR^OcIA{I*{q@4!Wvy5HXYE6XmAQ?4sd_G^UOM&sD1 zAM3Sp)8nqN5z`l^)Roi%NrjDzKryt}g{{WN zr)~KXu`_$!VXL38%Mi8&^N;IWpyge1Vmadl{rKIlXt5Ic#x;d!Cty#lRN$}qzm9Vf zo)Em(xgb5>s)J?GKj&j!a+}OCuWKCTqni>Z>B^@;5H#L+U;1_XK`7-mTmU@-0nsqz z@-9ivDKx};O5sCND9XiMebbAA%m%O4x6JE4wqr%Sj9Q=0(EnHYXimkGg?b4GElY?b z2wPB+5`+YKrj$RxT#)hY!-nSF0{cooRMi4d9i)CVm8QH^TS00qP4yzgXb*>Rm+2!R z&FX5%iyW_8p@qM={ADWcE|g^Z2L&j7FV~CFx4tK4!u8_X{DBkTQHZTwf)X~{^R+0( z<>a*P9HAtX1}W)fI9FU!34Sjy;r-yRO4sNg%;SIyLwO{|@{+kT_42_j%Fif_2ghP8 z;K6;S6-Prvlv|J(KdwLA-Xh^Lqs81}*k8KeyRIe)R;ewtov#CGBhji3Fu`ZpB$W=? zfJ^MqWZwV6f;HkbnijhX56m|=TU=flOx-MG6@J!c$mhSvZnF62&j0;m*#E|vqW?eA V^1nUO@_&cA{9j-HchQ$K{}1{cQaS(t literal 0 HcmV?d00001 diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst similarity index 82% rename from docs/source/recipes/librispeech/index.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/index.rst index 568a8016f..aa97f325d 100644 --- a/docs/source/recipes/librispeech/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst @@ -6,5 +6,6 @@ LibriSpeech tdnn_lstm_ctc conformer_ctc + pruned_transducer_stateless lstm_pruned_stateless_transducer zipformer_mmi diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst new file mode 100644 index 000000000..d8569bc5c --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -0,0 +1,545 @@ +Pruned transducer statelessX +============================ + +This tutorial shows you how to run a conformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + We will take pruned_transducer_stateless4 as an example in this tutorial. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. 
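As a rough illustration of what "pruned RNN-T" means in practice, the sketch below
first computes a cheap "simple" loss on low-dimensional projections, uses its
gradients to select a narrow band of the time-by-symbol lattice, and then evaluates
the full loss only inside that band. It is an editorial sketch, not the recipe's
``model.py``: the tensor shapes, the blank id of 0, the plain sum standing in for the
joiner network, and the 0.5 scale on the simple loss are illustrative assumptions,
and the ``k2`` keyword arguments should be checked against your installed ``k2``
version.

.. code-block:: python

   import torch
   import k2

   B, T, S, V = 2, 50, 10, 500  # batch, frames, label length, vocab size (toy values)
   s_range = 5                  # width of the pruned band

   am = torch.randn(B, T, V)      # encoder output projected to the vocab size
   lm = torch.randn(B, S + 1, V)  # decoder output projected to the vocab size
   symbols = torch.randint(1, V, (B, S))
   boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

   # Stage 1: the cheap "simple" loss; its gradients say which (t, u) cells matter.
   simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
       lm=lm, am=am, symbols=symbols, termination_symbol=0,
       boundary=boundary, reduction="sum", return_grad=True,
   )

   # Stage 2: keep only a band of width s_range and evaluate the full loss there.
   ranges = k2.get_rnnt_prune_ranges(
       px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=s_range
   )
   am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
   logits = am_pruned + lm_pruned  # the real model runs the joiner network here
   pruned_loss = k2.rnnt_loss_pruned(
       logits=logits, symbols=symbols, ranges=ranges,
       termination_symbol=0, boundary=boundary, reduction="sum",
   )
   loss = 0.5 * simple_loss + pruned_loss  # the relative scale is a training knob

Restricting the joiner to ``s_range`` symbols per frame is what keeps the memory cost
of the full four-dimensional joiner output manageable.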
+
+.. note::
+
+   You can find the paper about pruned RNN-T at the following address:
+
+   ``_
+
+The transducer model consists of 3 parts:
+
+  - Encoder, a.k.a., the transcription network. We use a Conformer model
+    (the reworked version by Daniel Povey)
+  - Decoder, a.k.a., the prediction network. We use a stateless model consisting of
+    ``nn.Embedding`` and ``nn.Conv1d`` (a minimal sketch is given below)
+  - Joiner, a.k.a., the joint network.
+
+.. caution::
+
+   Contrary to the conventional RNN-T models, we use a stateless decoder.
+   That is, it has no recurrent connections.
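To make the "stateless decoder" idea concrete, here is a small illustrative sketch of
such a prediction network: an embedding followed by a causal, depthwise 1-D
convolution over the previous couple of labels, with no recurrent state. It is a
simplification for this tutorial, not the actual icefall ``Decoder`` class; the layer
sizes and the ``context_size`` default are arbitrary choices.

.. code-block:: python

   import torch
   import torch.nn as nn
   import torch.nn.functional as F


   class StatelessDecoder(nn.Module):
       def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
           super().__init__()
           self.embedding = nn.Embedding(vocab_size, embed_dim)
           self.context_size = context_size
           # Depthwise convolution over the previous `context_size` labels.
           self.conv = nn.Conv1d(
               embed_dim, embed_dim, kernel_size=context_size, groups=embed_dim
           )

       def forward(self, y: torch.Tensor) -> torch.Tensor:
           # y: (batch, U) label ids -> (batch, U, embed_dim)
           emb = self.embedding(y).permute(0, 2, 1)      # (B, D, U)
           emb = F.pad(emb, (self.context_size - 1, 0))  # left padding keeps it causal
           out = self.conv(emb).permute(0, 2, 1)         # (B, U, D)
           return torch.relu(out)


   decoder = StatelessDecoder(vocab_size=500, embed_dim=512)
   labels = torch.randint(0, 500, (4, 10))
   print(decoder(labels).shape)  # torch.Size([4, 10, 512])

Because each output position depends only on a fixed, small label context, decoding
code can cache decoder outputs keyed by the last ``context_size`` labels instead of
carrying an RNN state around.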
+ + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless4/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless4/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless4/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless4/train.py`` directly. + + +.. NOTE:: + + The options for `pruned_transducer_stateless5 `_ are a little different from + other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless4/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. 
Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless4/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/ + + [2022-11-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. figure:: images/librispeech-pruned-transducer-tensorboard-log.jpg + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 6 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5" + ./pruned_transducer_stateless4/train.py \ + --world-size 6 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --max-duration 300 + + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/decode.py --help + +shows the options for decoding. + +The following shows two examples (for two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. 
code-block:: bash
+
+  for m in greedy_search fast_beam_search modified_beam_search; do
+    for iter in 474000; do
+      for avg in 8 10 12 14 16 18; do
+        ./pruned_transducer_stateless4/decode.py \
+          --iter $iter \
+          --avg $avg \
+          --exp-dir pruned_transducer_stateless4/exp \
+          --max-duration 600 \
+          --decoding-method $m
+      done
+    done
+  done
+
+
+.. Note::
+
+   The supported decoding methods are as follows:
+
+   - ``greedy_search`` : It takes the symbol with the largest posterior probability
+     at each frame as the decoding result.
+
+   - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
+     `espnet/nets/beam_search_transducer.py `_
+     is used as a reference. Basically, it keeps the top-k states for each frame and expands the kept states
+     with their own contexts to the next frame.
+
+   - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+     runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+   - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+     given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+     our paper at https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+   - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+     a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+     (with an N-gram LM).
+
+   - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+     - (1) Use ``fast_beam_search`` to get a lattice
+     - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+     - (3) Remove duplicates from the selected paths
+     - (4) Intersect the selected paths with the lattice and compute the
+       shortest path from the intersection result
+     - (5) The path with the largest score is used as the decoding output.
+
+   - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the
+     only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
+
+Export Model
+------------
+
+`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+  # Assume that --epoch 25 --avg 3 produces the smallest WER
+  # (You can get such information after running ./pruned_transducer_stateless4/decode.py)
+
+  epoch=25
+  avg=3
+
+  ./pruned_transducer_stateless4/export.py \
+    --exp-dir ./pruned_transducer_stateless4/exp \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --epoch $epoch \
+    --avg $avg
+
+It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
+
+.. hint::
+
+   To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``,
+   you can run:
+
+   .. code-block:: bash
+
+      cd pruned_transducer_stateless4/exp
+      ln -s pretrained.pt epoch-999.pt
+
+   And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to
+   ``./pruned_transducer_stateless4/decode.py``.
+
+To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you
+can run:
+
+.. 
code-block:: bash + + ./pruned_transducer_stateless4/pretrained.py \ + --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 25 \ + --avg 3 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +.. NOTE:: + + You will need this ``cpu_jit.pt`` when deploying with Sherpa framework. + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless `_ + + - `pruned_transducer_stateless2 `_ + + - `pruned_transducer_stateless4 `_ + + - `pruned_transducer_stateless5 `_ + + See ``_ + for the details of the above pretrained models + + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst similarity index 100% rename from docs/source/recipes/librispeech/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst diff --git a/docs/source/recipes/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst similarity index 100% rename from docs/source/recipes/librispeech/zipformer_mmi.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/Non-streaming-ASR/timit/index.rst similarity index 100% rename from docs/source/recipes/timit/index.rst rename to docs/source/recipes/Non-streaming-ASR/timit/index.rst diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst similarity index 100% rename from docs/source/recipes/timit/tdnn_ligru_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst similarity index 100% rename from docs/source/recipes/timit/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst diff --git a/docs/source/recipes/yesno/images/tdnn-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png similarity index 100% rename from docs/source/recipes/yesno/images/tdnn-tensorboard-log.png rename to docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png diff --git a/docs/source/recipes/yesno/index.rst b/docs/source/recipes/Non-streaming-ASR/yesno/index.rst similarity index 100% rename from docs/source/recipes/yesno/index.rst rename to docs/source/recipes/Non-streaming-ASR/yesno/index.rst diff --git a/docs/source/recipes/yesno/tdnn.rst b/docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst similarity index 100% rename from docs/source/recipes/yesno/tdnn.rst rename to 
docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst
diff --git a/docs/source/recipes/Streaming-ASR/index.rst b/docs/source/recipes/Streaming-ASR/index.rst
new file mode 100644
index 000000000..8c0ffe447
--- /dev/null
+++ b/docs/source/recipes/Streaming-ASR/index.rst
@@ -0,0 +1,12 @@
+Streaming ASR
+=============
+
+.. toctree::
+   :maxdepth: 1
+
+   introduction
+
+.. toctree::
+   :maxdepth: 2
+
+   librispeech/index
diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst
new file mode 100644
index 000000000..d81156659
--- /dev/null
+++ b/docs/source/recipes/Streaming-ASR/introduction.rst
@@ -0,0 +1,52 @@
+Introduction
+============
+
+This page shows you how we implement streaming **X-former transducer** models for ASR.
+
+.. HINT::
+   X-former transducer here means the encoder of the transducer model uses Multi-Head Attention,
+   like `Conformer `_, `EmFormer `_, etc.
+
+Currently we have implemented two types of streaming models: one uses Conformer as the encoder,
+the other uses Emformer as the encoder.
+
+Streaming Conformer
+-------------------
+
+The main idea of training a streaming model is to make the model see only limited contexts
+at training time. We can achieve this by applying a mask to the output of self-attention.
+In icefall, we implement the streaming conformer in the same way as `WeNet `_ does
+(a toy sketch of such a chunk-wise attention mask is given at the end of this page).
+
+.. NOTE::
+   The conformer-transducer recipes for the LibriSpeech dataset, such as `pruned_transducer_stateless `_,
+   `pruned_transducer_stateless2 `_,
+   `pruned_transducer_stateless3 `_,
+   `pruned_transducer_stateless4 `_,
+   `pruned_transducer_stateless5 `_,
+   all support streaming.
+
+.. NOTE::
+   Training a streaming conformer model in ``icefall`` is almost the same as training a
+   non-streaming model; all you need to do is pass several extra arguments.
+   See :doc:`Pruned transducer statelessX ` for more details.
+
+.. HINT::
+   If you want to adapt a non-streaming conformer model to be streaming, please refer
+   to `this pull request `_.
+
+
+Streaming Emformer
+------------------
+
+The Emformer model proposed `here `_ uses more
+complicated techniques. It has a memory bank component to memorize history information;
+what's more, it also introduces right context at training time by hard-copying part of
+the input features.
+
+We have three variants of Emformer models in ``icefall``.
+
+  - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe `_.
+  - ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourselves. Different from the Emformer in torchaudio,
+    ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
+    See `LibriSpeech recipe `_.
+  - ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourselves. The only difference from the above one is that
+    it uses a simplified memory bank. See `LibriSpeech recipe `_.
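+
+A toy chunk-mask example
+--------------------------
+
+Below is a toy sketch of the kind of chunk-wise attention mask used to limit the
+context a streaming conformer sees at training time (this is **not** the actual
+icefall implementation; the function name and its arguments are made up for
+illustration): every frame may attend to its own chunk and to a limited number
+of chunks on its left.
+
+.. code-block:: python
+
+   import torch
+
+   def chunk_attention_mask(seq_len: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
+       """Return a (seq_len, seq_len) boolean mask where True means the position may be attended to."""
+       chunk_idx = torch.arange(seq_len) // chunk_size  # chunk index of each frame
+       q = chunk_idx.unsqueeze(1)  # chunk index of the query frame
+       k = chunk_idx.unsqueeze(0)  # chunk index of the key frame
+       # A query may attend to its own chunk and at most num_left_chunks chunks to its left.
+       return (k <= q) & (k >= q - num_left_chunks)
+
+   # 8 frames, chunks of 2 frames, 1 left chunk:
+   print(chunk_attention_mask(seq_len=8, chunk_size=2, num_left_chunks=1).int())
+
+Masks of this kind are what make the model usable chunk by chunk at inference
+time, since it never attends to frames it would not have seen in streaming mode.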
diff --git a/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png
rename to docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png
diff --git a/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9c77b8bae243ec78c33ba7b06d28bfba408d1724
GIT binary patch
literal 560358
[... binary JPEG data (TensorBoard screenshot for the streaming pruned transducer) omitted ...]
z9bYj+q@JPqQ#iZ*TGv%$pNiWnDa6~Vs-CV&?&YXDtk6}X;e@{ve;9AcQo?x5Y-WJ* zwgzWm?f8b+Vsh&zs8{pY#-@smU0wCf)|r`C1eJu04nwKgNXkv90NjcBF;G&L z3Y}>8@QoI_+cfp+Y<-4<5u)5q$+l2?N9K56^P|KkALgo)b#6vGph&88V-zqVdV=TAUaEwbmrScXYJo3B=ln`=zZcs3xwcC3IG&p@Wv{+l*m)}+GA&3N~f*Pp042hvM?yv zf6x#82q5e@-&K;bX+nT)fFQ<77^FKY)y`&vLY^WL?$*4wES-2JJ2sCGZRUe&0 zWNn3KQN)3{pLGZcm@_oFX$u48#leJz-({~@gf$jUE!AdC-YH7If2CXpH2cu`Sw??^ zKbRrF))NXvL7eTO;y+{BYZyk_DM;Uz2i-R3vsihCSV5)_4Z($C;kY#+q3NGRLO`o?@b564g%s8M4`H`KE{)%eX%a#O#`?%93w5W6oAj8V)m(JuGvu?l>|VQI*6 z_6STk-ryo)ww*%{;4R{`AT(e`xffxVX1#DJ*(v_#C!f;q(no%n%f$?Z{H{MLX?~{B z6~q*O+oI8p2DgCU!MVXwwJ?=BrVwF7yxAg5fiajL`}s_I_i;B*Ht%Y;%?CAm{!oQ= zhy=u^m91lf;qmoA3TDU4VQY_cpL74Q_H$j)sq{y4uFmf>{ZbZ(@?&`(_Vs`eX=6}* zAm}rPaX@Bp5ZsNe;W*6J`=l*t(hg(FSHYg=QZCT^+9a{maG?i2jR&VL0wk*mExNvx zP}LZmzye20yklpfX`uT0__yN@E(y6q*GN(vB`LPFE-qmD3$!M#cOFbdkq;k0#xzTI z!(gf_mFMOzHfKlPOmaQnxd1t{*`^G?Qcb%~UhD?*Yl)kyl3E~-r*9dZI1(agCT`sT z+F57=NUfvnfE&+L|E@}=B{Bf}f8Trkr|+!J!F5l94KwN#;0k0%T$E?NX*p5EI8?=a zt9Mfpn}eTi=Cs&Lmib(u)#?Hek|IMI`brC4^|R(td3n}xZiy|B_pokxB)Wki!C}7u*t;6FZ)(`Op~;|=dOoxk+6pS zY-Re*v^1ehrb8E4_52;=No@3+u!^P84H6mX_?JpAoN;ud6|9&gZsM+Fy-ihGSM+h+ z7IZP0<8nIvLsl|iD#GA<1rPJh!TD8a%_znKonSYTt6kyk?SG^(YJO7gbP$dQn#}WS zJGdK+W1~)zePdJ*8vW!6!)mf}V%b-m!{pb|@?Pg2%UnfMJ`?;6+e;dzBg?wiA0g%s zbjI6GAu!tLXToqqgGInSHoPs4S#-PPt2FZ6R{@;MyT zk1>U-TI;`lD*GiazoWeQ&itWIetL(i-hQ;9(1-uYFo6*SNLh8ACuF`ROQ2#>jezy-I2Ad5_6DZZ1|03V$AX zU#Axpbs$P(k_Ye5ya`bc5_j2^%dH++*ni%3e#L?fDiD35qv+Z<;z`--Lh>QKRKKxr z>ah~g96HHx_QNAnnMd+0jClg`?fWKM(pECLO47__jO5aI7p^iaXx@RILtCFGbYWM=<%ur+t_?s`>5IhqQWGjOI{?GBhsoUqm>tgbhq!OA z39HFQ6&8(V9oA<40l_fe*F?>Roe_ft=YheGT9g`F+HK0tWg&z#C|JgX~q_Y zVAcb`!RKq?T?!8)oN7-Be4*K>@+;0g+u114<~jd4-fPp-1|qa#@*Xxc$wNwL+WVJM zPfC&66?3OVL*+h+L#8WGWyJ8oaWXogZZv+IhBB_6bmVrO(r&I#d=k-87^AM(p-+K2 zx-T3fL*sk^q-xwQH{@EQtAcA?U7{!lKf6QU>!q~w)YrU{Ul`$p(nZ|uZeR|vFRQz@ z*e!?4zTZta28`S6hd5jQfIj(fK14sb4~~&-&Zm~sa=4#$Ev)Xk-&MRC%Py1nMM+Rk zLP8?A6d?-q4);@hnGx>}?mmYJAKcg@2s5Kd?Yj=IKp|WPVzefthj_{5o*TiDW0$;B-r1!fq-#q@IsQiAx z?@ILM%8->MSS9RH#6_^c{J|KS{V9}YSy#b718cCXaq+%{*+}rM){!w!ykhJukUlE1 zW}Te)(?Fkbe;R&)+>CM5rG*49BfTAQBxQax$Az#E*Kc=?{R*@0FS$saNo}X=gE#E4 z6iAOHpQQZY=@nA&Ov85b@BESZc1NXn<~xS|L%d~V0Xum!2j;u65om9Qf0dXihq&_H zUTHf$8hgG66xM)o&?P>{?Y7C$vOdoA7A@FUC7IPH`_!e+*s_{`kGiVz-S<_jVpIwveFUu~D!(m7r(?rxYyf8laWEVT?)GSLAT)01SqQei zO07$|d%{Cp{D)2W_j}1b&1{>yqjg~`DZeUG3}h%6nk;^@nihPmCf>rJF~r9;^{!L# zU<22uWSbx!Q#mO^2@0ed&kCiz5rte_FE`s%-=;FSPrvy{t0Aml(Z}KEqq`drLSNu& zTweXf=cBe)XMKL`_AfmK9>BUW9EnC%&%FJVQ>H`> zC5NMX_+Ihe9p9Zb8G2%-hw&E6XlzS`Ff^6<&Qwla@k~c) z#J9dvPcGczh7XrBS_^=1}xzj!!wlg%ad*rc=5kl|U|NIl*cNa0e(6&i|I!8;F7B&j|V zy14vvS%s_S(m11di_(oNv z2T6*P6uO5w=b87o*AwX~6WLwT3Qw2rU64McF`4t#itPcxmk<7gCRHzWmQv8}O$2jW z_@zPL;)-gg2)Fl$Y9IO}kKWLTtaXzeYtw<-f7RoA`?Cod_4@kdZpX8^-n)|zp0m{d zxbOc`*zom|Yc=e!6k89prZxX(jeJU*;`(TzhKZb53|aeo5_kTy;zz*`MYj@rY5lIg zOX&FaqrV(3J<2p0e|fKU=mqXm)G1quR2JDEkWKVe&;;c;v)DkAW^s#xctdai*eQKI z&*lid#huTaRzQYv#(#xmAe_t3b?`-9_pqSoVoT{ts03#QFJA};Fg9S@8A}Y68O>=J zq`wSbMAWUL#eF|*v>m54c{YxVe4kKpHVFc042oX4xn>&=3O?QD;qHQ_O3fUfGk)#c zsm0xW01^Z36Y)RFkid{FjNOk?`Bz~dR&eRjQ3^@ijRB**kKd(8P#k%KE z>GH1l)wFOmM?t#moKdhZaj?{lAEVtM`>v$V)yYhHpU;qonOevP*QLl6+17goE60dF zjF&#o+PEwDt{9BFOmN8w9j-3|V_vD9a;d8+==V{Xzb)bcpIp)ojnZX&Ufx-(jxB7` z>GBov^!`x#Opk^)Z$DBWmX41`Ogfje^J0qKT>!kEaqs{MC(%J%mq=a6V1w z8ZmlSPXMW?GNakVWx$I-`h{uTVJX?=|JafaqY-& zP+W^v+a(WVJodHUsg!wmANDW$x)c7^R=X)yS(S2pMwD7!O!fK$%7~%F9bn3Ut89K? zT~p$>ZcoXr zBt1*T)UVfvtW0>s+*r#sdVNG*k~;VDXg1WSFE}DKf~@u9&8K-<5aw33r@}9B({ue? 
zYZHFeT;6l8#y*I335ublYGNG(cVcNF2b{{dOwhFyai?yZXo;(NO4gWUR9+vucqW5I z+J$&H$~u|YM#B&mC07eL#}Nt)^%lBG&f_Xa#~Dqv%EqeaN?Tk5F4sRIybB+g`)Kd< z(K4x_$?m~!J?(;ehEM&_t6=Z8;kF^vuc*rnqP>p4)hczapKi zbUmzI=@lJ;S`_poj-pzSU4Kg;A)jgl*a^nFe>l1Pd>e2(pW-!d?+5Qz<)%6$XV zy*I6O;igyFryuDvr~%A_4ymG45NnAv2kl*FX!v*32DHLF=u;_7+KxKr=6a5Ed zRFTrg&d_S5@H7xrx!P0e&|HuAZxktJk9U19>mJLnzGlt$dD*D|rnIKcvdbIQ5>7eM zg*@?wW=?TjTf7@?H(&mA_{OnkryU})Q?FQk}Iex(jBfxk_tVn26qcwz+lko4~ zdE+O?yl?&3-Y3XtR{S_N>t%+I1!aY?zI@f;B>J=>GE_T>!im|RYNoK>o(gQ$YoP|= zQpSg6vG4c~+nl2A?0A>7+{A-pkRtH#54fNEt($Q~6!!=7O)rbt8v^i;=B|f~??*P8 zotO5(9sym)>w6M|quWl}(&T$*D39VkVuyLEJRdpJH9nlTWz-#eorM(#+>Uu;BZ_!0 z5H@I`67Mdu8Hmo$y60S)AT`*Q1_@OfgR3JH|*$>+4oALc~X0(T&65D z+D?xVE!YsOO=m%G3ts-_r6mWiG&4)I%5=iu*d21J;NAFq`zI~5&(m>hK4zbNCp0LsopT;#+G&X5I zcgHqU_Ci=@%kJjqZI4ud{DAv!&*a3e3wd!Njf^oTaf7hc6pe zeNA()pVbvl5?WAaNg>;O+2R@Olhw0gtA$GegX$O_c$RRYH@`QhZCx&DCGnn1jH?0* zKejrcD*9oGcw}c!=+VZHm^bcCu?>DHa(3Y@4LqfT>RmB?Qp-F_y%4v>guKP>_G1on zv#;23l;Veh@8nrh>8(a@3JrAH~E(vc=00g)yjU783H=|n)9L{GLZJ zB}BO5)M-)L@L(pfJ{nNOX0lq8jNZ5!JIv)bC@bu_qICa&=gDMd@%fS*FBRsJ_QDJz3ihs(q5_XZnn6>^0(LwX1WnoNbG_F`V~5zN<&_>CX5ulx#@#V zdwTb2=SBS(()&)qvp?P6_w!F9L;M|%w}Ft?|au8C9B=-s!Ozg0caK0 z;trRwQA~eLNi;P!?w2u0d)7<0cC?JCXfPyg&N)DiMVC4?IyVwz*hoKr%AR4$m4Sg! zwCs1_93-GoP3k1BlWe7k)!Cw+*>NBut7_Q5ZY4|hT_uFnWiefd6yhMg|w1a$P zT)a9Lw`b4CN-AICR9j#C!jNor+0HHNW53!cdCKXeOqYGZaIG!UwU!u=%vO_ne68Lr zglFJtY}y+Up)-tQ^l5K}Dlps%Rh7R%u)xF+m!bE4q2IDDWtAaPLz9}W%mbmW6w43v zg1h5X9Z=PhaQHxvt1TcfU1?Zcaw)1mXalv_5hwq45Us7FW8q(`0E=%6HDzv&{)A9novMMS`>( zmu;W(x{46cfQ~!5cZ9b4X6>w8U7`BqebZ{sScopUV)yo&8vYx0T^{~`erX*Blva=* zN0ld6&tLbA7}+>jpk=0K-U!yy0h9o78GUWCB%rXGBl8M1JF-em_X833GLD#f=8$r5 zhCqrfd-r#(`hJqKKoFXA-zzSFz_eG>q)Q5^J1RJW^&&Mtx3f`Xv;Y}H^rIfFHhrBlJIrCUpoo6nOayt_a;$(Z$o+4pk&j^I~e7tav$iO%;8h%RLcSfBqb>A za19mlLy;X3JaL#A$BS@oj~f(F8c^_`E>dvUroY}_m*(Jmg^ulK^`*baBB;u54GVow zi6OnG)x(1-C$1O7<$45V>5sh(DpDeQm%(jGZba8~thpz}z*Bz6kMZ`z5Oa#-wIN{z zrC5pn@i;vZU2aI-UXPdrI7y&gYU+n9jNtuYa=p$;(noMZXHZ=Tmye?l#;KY+bKSdjap+*OvGm;Bzh2A^fiNLRzLmEY4BV=g&GJ=aW@ zCFPM(_zE2$?l`3y>>I8Zn*W9Sg)Vj&e6laQ_5BwoDZ%y)!ilW?(m8!f;oU5IW@}Bq zS}pI7t{RH$*3ciM%Kyu<`#&L1{;S8|lm7;(0@Ry-wT}U}T`%GU^bfd|d_w{Rz^$a{ z*xtJvI($3rKyc-+Q>ryQSv3{P!&$9#d*kCwUw8O{U!kbU0QeLk%>H%YMcM=qQJNBD zs0o3T9H2V9rb_dJ;a}h_vjDo~Ux6b32CnopD`r3d+=?|078mvER-(rfBjUR8%O_8A zv+WW@`PVrE0M!Izn5@h6;;$9@mpqB0DNZqK zl+5Z%Wa>7D-MI5J2ku9iH@vrPCIBbBr+@VPtlg-EifQ#2c6R#_0fQY&u0}lY4SGv{ z-TTw$P87_wMWeAo1SDkLP+G0mx6?udm~*rB#Ayhpsg{6;jR8P6@`}Gf*E)@N06?kC z#<5QEQTS)bj~ls%#Sf8mWP{FPK^-?D+N`xb{}!>%p0Lubnypq2yZa`^%`n-;l=)J# z(O#PMnqoASi-_sPFj71}qZLbMS~(uQe5_>3voN5*pWC|O^K_9WL%sRU&(Te?R{1t_+J+lW$Xx+_zKHWP#mMe?V%;19&O!K{Bg008^Al@TROE{O+ z1?7je%hvcSFnqC@=$G>{6pz_Rw`@r{dbSq6Jw;etYc244iw|u0rh8_BxJ8b^Z{1K{_^vC}b{v}bE*KK5 z{)UMj+kkJ<4tk^~?s*)tSWNGSItB5maH}KJ^<)_!w>qhxrcKP!`={2Kujq8|LDe}E z9ahCL-l&Y?I`elB(q?k~=By4xNhntiq0}pX-@EXOvd*nA%_E2J`MsG5Knr4+aB!3C zi$75LrnPP+IubBcdp3RM27LbWv`Vuq?TxgITyNZJcQHFcnH*JyxI|bFI8smcX$VS9 zM@zk-S=rCEvwv_(j%;;G2E+lVavYF?h`E@ZE5~-Ol#pbDZ$OD3nfqFgX(`k8v-`2( z0F+d^tfTlttfqMCI{_69sYL_%*gU4Y;B^8LT)#B4&~~soBK;$K_`H{)VfszYHSwDa zzx3&WCg7Jgd^AKI40jM4>6jcv%-v~JJAxn)V|hI5xDj|;f!)8}-?W)>r0SZAO=g=u8!sN!@YC%=RLw7TC? zp#N@|b=WIYgr245;#TJ)|kE-{HA1*CBTUblH*}*$HKAy#RoJ=86vC7!g6KP(U!z&et=#0diLweHk-U)9S02y2IMFw$$69cV>ZA(S6QB?4piFa1=HdP? 
z=AGwy)dua0Om@l~bIZ@ac{yO^7;j~%2W*U=q`m@nIuan0spF29;U-7S+P5|2c;f{X zyws_-+CN;12nm!2#Ngb6q2b z;5cJ+HQ^VxI-p$Jy34O_z0YQ87%VBIIi;#d;Zd;QJQHgdW265|AA>DiWvMq8kK*hr zQYMC{RM`u}cB;wSd6}z<(=GJf+3!hdvT}0LlVRpy2K{PT`{8bWQh@P5i}($V;4qw) zxCdc>UGOluIm);~W&+6tH1*r}+t8f25yO)a8YOTU2clP7&}kBXZrZE^PafmB2FKx= zfq}F6ue3H(zC6P2#mJBi-3SIQAG9{(D?^7}WQN+#mgli*PHH|~P@8)#0{+|@v?g!c*^q$-Cn|&mFWjCB<{fsNekdr&V!#&l9dr;lSBA?4 zfT5V%h0YxX`*$0OK{M_(Ds1!_&}HK@zd`qLLY<@$c>ne?esOGpv$q)3&ij3$R;IIs z-QdRT*`Tgv-vbdbb@ee3JngyuGVt}2CljGAGTjA+c`+pmhHkIEjRQhjt$zz?rC%+* zm;nOac1WdMJb8{315)+b4BfMEvmNudWYX4-L&ZH}ot1Kqsgr!JD1&>@-r&n&V2Ux``!`-x3F~F<%4yV2X$Y3mwB5}_WLrFbdMZXNkIT%hR-mjo|Dg~!Q7M2hnT9tAf09| z<)!k)&h&b$6bRJn;Z>)rj$k4r_7_6Bj{;No6f5zBl?o1Yz+I0N^#@pcJ13pFmkiAA zS+vEZcu+930ZIXC2IymYbIMFR`$-9A<8%=n;E5rezd>iA@1VT8;IFHIY-2S-i^w)B za1?;KJ5xL8H}L44u8MV`nEqOG_KvIR8h9QCBvI+@biR<Cne+UlE`7?6(Z{cq|iitZpyy5V8V(MwU00;87j!F z9WsAPY-`stGlCk2!^G-p0g_O1zOQ;qSLL3K4a+>EGpfn7l~ghm8>DSOgb->0*aA=M zDUEKF=(mcBuOU6zs=+y;g3+fWv~W3weFh55y7T<|{%8vCCa%2J_Xw4$0|V?IKN)Ho z8tKAx<8$k>7DO=~DlchJ#{sfwy=Mw&wLsXY3h-KzFAyKxg@^YCIac+3_yYeix6Air zO*o!1!uEk?P+vrd&m&-{+6Yz+D4%-cr@=Hj)`7~Z8I)Dn`8m_0Tb?4ucZmQScknO|(%`OSa z5$3P_3i2sD4&5JoQ%3V_Q&V>a*;gS~)@UPPxh3jOIuT9Na?{oJIDs%1moN?=B|o_D0f(;{;W;Ay(w62hD+?oq^Bs74NiBmYBbHZDvE~RQ;m3=;+?}f zFK&7#gfB?;tctqsE$kus`40?P2Hj9ThpHgM|$KJ+P z@ZIuAK&(x%xA>Q#eCA1NyXbGwR1hO zWcKGm8d4}>XsG9wv7E&5(EGg0?&pji_I|9pF8Iq)Qj;$BBE|gVX)6=a8J8H3YTk2-&@}Bp(|hE2*DZ1m0Kjl6`Nt^Z#dY;ENaDoUO-dLU*l0Hsm=M+M{P)d zira6{V|%#b$?~P&pzE#}7u_>(+3VzN00`3|yk4vSnw&5cDUvMz)ZQ}%$D*8G=wbkZ zOPh>Y!Fmv;$1;FspK%AfA}MHEP0Sv`?khpV-(#z%_-uRE!)0__)mZ8vDN#&W4_u64 z4c!w&0+rz(P%+BAmbjh>&fYEANZ2aR1Mg3hhXNu!?6Se{-9ab49#)-t538bfj_rIg zAVd)E3M&rPFYi?MhT`AFo!1jAw6QxTrkvOZfb5 z*jRj#N&~D-V1HSDpCuXZ6TSZ>os{+A1TJmwq`1uXAxeLP{-BKSOHJ51-++DF%(?w^)~R zkv@B>CjoQL-%Z)}D`huaUW8}=RAMO1(EXXGrv{+dPrMd;A~>GmMbtbRKg+*pPQT@U z?@>t@o3^O9A;R1pa5z-5uR_`F6%Ar>YO8;#uZmn=fqMPBlQPSztFEZm| zgwQQ_QZ(I7-1J3*qTjt!;^P2;LTRJH|ERa|Pr4ibQ%9&WavJjCCjgxAEoKk{JF>ne zCt5oA%i7)els5^+BKKII2mS>NlYP9+9;FSySTvEnC;_S_LXcel8^jhuL>aVAzlgp( zwVD1`Ul%{a!&r;V+gL>AZ;;$vd_Mub9c6dSp1+8yYaatIP%tv1b2?*++6{qy_cxz= z3w^1wVtR&vkX-=lop@SnVs<_wokCT%EW&Ba=l3m{w>aw>|O#$OmQC;CzBda%fH*o z`xdBU@zwvcq~kQ7RZM4(+4H9F@Nbz$fCP^v-nawJMlqVLY_(H=fhye+u>K_*_IgMn zyNo{T)WxOre)`*3?7F+y>tv)j9NOo~o7uT?zR~~skd&cAE(U=X1{Bld=uOksGn&q% z@yLDz(qXqak^>=9M_{2pDQp>PQh(8Xsu~^c3~J+xH`Inx#a^$pLGVmnmaK?tgq3GF zbJN#TKTgi!Wm7|IlzZ#C(uL27iQZ-Dk7m*ziHPC@Spy#(NDK&ownPINz{+ul`rFz} zRtE&w?0%VCxR*D}Q?aLeXGf}{!A#*^a8Sx4Jb3#AyPBj$LCV3NEpy&7w1c&2S@RC? 
zxrJ{yXKAu21x1^JXm%f4F^$YXKdg2?V_g<;XlR~Js3=>J|Jnw7_+VGs-rianKF4np zM|1U1Cnw}4RNNUSc9M~aL#6DJvi)i`Yg-LvNxod{pS(J6FrDG*hAUM^>NL13YjU*2s+kA2|F(e|y{pWKbK?k6QUo|vff9mGH!o?J#O0g> zwvP(i$qi&-71kLTx7QnowAOab6>+3TZx9L-R4}}2MA3ie$D*0j8uA{Y=Z%wNu%Oa_ z=Vh^UN5WoLaRQZR8d8tY#sTe(iRXW1(@E@Z)n{DaJbCU=GR`zSa(oePbR*TgnW&hg zpz3HM``Q%2{+^C}e_IgxinA?BtjuQtnpjui!fV)IW+rle=KLM~{Zd8!?@@#6MV4x> z4M8DPS;9_!YOC5?vb!}9`9;~Rn1yQ&LX}O}u1xP=>4k@HZ&WYO`8WV*B@09`H(cEz zrYwMMCaw2HfzHW#=K`}Rd(&}Odv^8x%`T)^3sJf_b+aaIOI=w<*@Y&V-iA$uN0+dU z65(@Uk9=b6B=HlH&xAxx3(T0%n_l%@{YS1Fec#F(j;r9GdT5gp&U}BACnw!P%T0x* zszR;-4=w1T`oi<)xo54EZKtaDNVTq2gjt+%0+QePdtQ-X4zA>L0CMu1b{X@i63>-4 zPpmXS>&LUN0ClNA!Xhp_vHi49;{r#V+Q{dI7h%#%#W0WT#(|WyvO8xMABu8Wy21vl ze}jPG&BVR8%P!x9UsaXPLzR8E?^fGu-`}gFO0s4>z=d`@u{Uc_OX zcmK=5sP+jZxE+HT16_tl__5uN>si9NI6Rns2>}U(u}k2W{jV8q_?{HEYL9BX)op=`e7rwsTW9+*v^pYdZT>noUajMY zU;9=Yzdd@ut4Zg#r+NFH<^WAi-#vy%J0K2kry=v=z_eSI_`IgJ+4>oq*kjvgV|i+K zVfql7 zOl}gK6RY@8haC+T7z~d>Z?#;n+IjHUlV6G7-GDf8oF^@C=BcpT<#Ua@ZTl^pan` z_=AT(mg4A&;PNeEM(p=*Hg>bqIg!)L5EjiJn_O%SVs^0UVhk-(v4+yqFM?4KbKMzB zVPj>$RtO=U#z7zIo_Eyu#lymRof?i|75DddZ{N{Br_s3;lcg_8mlh2h3kA=kI~VS^ z!0=#hzvkB|yP^B+@4h$cBCN81nz)on-M(kB0+hJ^_S+C%UBRz_k{|2}XOG5db%K-2 zOT1b|>E#^NP*A?H$i2-opokEchN@ScmERl+45w7Oqu8=8U!`Nny`EDZdnv9@zlsHr z!ffRyOm!of3D*&jS{^@P6JA;Q-NBoNg{fNY(}t2?Bt;?|vX_v<;#{=O@~3-lk}}Je zv~>!is&?fw)}AwX1VC1nqPD&*oTqzCfW{&;2*-Mf-LCv~0(n+BJs!0!jVs~q!^-`( zb08%U|CA;DX)9Cz5aD6AA?t|_Pfxd400wJJFPsPH>y;jmTnO-jb$E#D3`N^LV<) z_pSRo&%&*^b8E;sOn*7Nhcn?e4Ixa+pN z(s){Vf`dBF@W@X7KrxeH-A8$^D}9;~_W8}9Vx0%A#HK#V8E9My9{DwEtOw?yxD!S* z*T4(m9aQ02+@(3!Z?=_5LRplG(C-Ep0HJ4Q$&YB`nv;AAU-QVT3cS3#BT?>VBx7U! z^zK8#Ic@&AQ@{oRtd)Uc!SpFzQ0d}p%}$Qh!q8Bf zrqv2EMf43!FFWY=P8D5Cz%<9y&jmV)%Nj zcxz|DcK0NBb;ql{v^4s1sv^@@VTX9m*dLK@ZZBEG8ys=icP7ktPK`RQ=k!YzS->=K zv(c)739bThT{qcgdFK~fb7Yr3=#&rjW?YwG^C0feh8Th#B=&+IQiaJiFBiTY{dlx6 z;-*;WAL7?}E_v*h0GrZ173DcDeg3KvGLzXPf~TF~bAW=5!JWa2e;qwa-gxR6ysmxQ zgMH2wOEYQ&YXU+?SJb?17k9Fp>KG4f0hxa3l5QdUSAqVZ(ean3`Za$wf+Kzwvy(4% z>+%wokUZm+l8uvPMh2%ux)YDTT%A{NhA3Yi3FEV3`eX@2JRI7K36qCW-hWN*i4Lki zgFX20c@PpV!s_2IC}*L??V3$Ez;VV=`CYe~Gvm_;z16CUaw^SrO7fVeb%yb5r++4D z&u3ahY<`A8>?gOw+7u3qrq>QC_i#C<9!Ujc*j?2BFu#Nd zHl@JO?UiBL1g?MijyVO{|s0RYhz6T#dHFWQL-$zms;SZh2vrBv>U#w)T) za-Gto&$`qb0T}nco00!&SRROu*&2anWPCBD@b6~i=N}CJ&RGsYSoA|EhpyPNc@*}9 zN1OHOtowCW)AG3o7sb8qgR)d;Kx+j*{)Hvxe}HD<{QuH(s!_c+@3&R>Xxv04sAX=8 zS-2G0?afba4J~{hIeyD*zf^+0a;d1a8E{;it7HeLIOHoDwLv@f8KBu?c_d;U8b{#~1wZ1%G_OA7Ajt7yR)Be|*6oU+~8l z{P6{Ue8K$l&DqHLY5DK}!s1EbMor1DwW;X)tM2Av)tGQcKfnw$n z>_Aca_w;{TW9(f`1as`Zs39gaip780YP)kvYVxiQ;{}!$kcp-gCu(qPXp>=bSx{^uWmf9{n>k^M@7FHq*bJmY4-fkEs|AT?rR+ckD@k_VJU` zk=JZYJmxgzcsz~H@TlmW@h+fP)z-F52jJenkKa`Y5OdGaoUU!D>9C{AG5sq0Z8f@x zp9qP8_?$3x7s=G-o3|Ub97v$>Mzv2L$DU#!vWxp$Ub&Sj{o&}N*n5#(UjMDvin zkIeugP|l9a7$0B`G@VngE`=1CR4$H~=)JR>5yYjPJZC~1t(#H$aW!#8hp-9I2Qz#V z*Cjkg>tulw>CZ@|Xd4TP6W?RrOeZ82~A8E+8VV8z{kO*C+vdiB{qbWnZ`1{bq2(%_Xk6T5T|dag~o zd!Yw;c3gD?Dfr3yyZ!jmJLN%Z6(&&jk~bes%|9ch+!d2rb&nR|o?+Uyl~S&ePNyF2 zNGLEK2!!vfBj$SV{v%s3pVj{jrh>C$`j;ivNU2RHg+{rW)Bbfql_EP^>w?1t`ThCdcq(Ana^VrxtT7q0GyEH5 zf<@1)2W*?-sxQCOG0-4f1yJ9T%Id}bza}rk9eZKsOvg$(* z{@X@Q!FAb%{aqAf>oYRw(D`e?ojRrI^!^-ot6w0f^$z#93ZYvS(EQII4|5kMC3n&K6? zP3uozfcZWAWk2ElP=BeYvOiCEj;`S(?>7io6>sV)>PJQd=p-5tQJ_0;pq|&y_x%Pf-&cu;#C&PQV4vjYaU;jVR|F)tk7nYk zR}J1eLxxpt)w^C!@2iyg=)N>TdV@a;t@Q`M7! 
zxv#@ZVPfAmw*dhQ=~`0pZ&1xB^yETY38x;J{n(IVKZ8_(p=!r=7;wwlhD2oH9W`z3 zD(Q8Y?C`C3g%y$}9h7;daLwCx*AJL_iWeLJXc7@X@M2}Q`*BqeS0ZB!%Se@9b5G8> zr?2xtmDjEg?NXxe&9i6mH`KU_0JB|>i+u$5u!+Ax6Z!m5l1e0s6FE^_R~#4Cw~sIm z!z=Ennj)QK-siy*Qw;xMq%A{PVl7TjJehC+5N4MFaNSe@>w#Vf>As0dJY40zHtADP zohjR%kDSOXko)V?w{u2}&eBOXWx;uyR4wXf)aiCGb;2G>t1CoSnC+d77){F0Qx<@~ zI`+u9-ty}4KuuMvFL(G*++xJ0WWh;8HGW1{i#n>yHtuMS_n}ydmvUduVby#UJt)aB zaP962Uj2uf_f6n)6n%w|q#)F^k0qRDMB|;)L@+kIlU0$Q@_=xX+uw7=+sYsE*MzEu zyV%lx@zNQu2T$BkGfMKKAo}-ABdwOl`p7v9gDMP_j1fgJ;1-MNe$}uWiLu~A-hRur z=;SX~86cY2J)x_;XeshoG((%pqWsD)7=DZuZ)n~kL(Lmz75s}Gxso55H-C&YipAv) zT(!EktHB*z83p@5+}kcBK7tZlt6oyBoxI$WBJ_^5L}7Lc3@hzd^nNAh7~Fet!~OC= ze-#KB?{n1t49_mn?>bQ}uxB}4lJaH6W5XlClT!p;dx73z6jIl?Bdp6f?<35VP-jeZ z$Y}KrBVFM&HQw*}K>Ok)FpZR;s!(=P8WGPnybWg25yfxr z&K%2FQ5M#;8UZ-5-=^o4ezxdD!UAf_5amapr2%aGP%L!L2N3( z^NT->TZmQw_?r?>s*tqv^W;rZ62T}2g?-=!@WW=Zn_q63qW{D#e51(NK_m&RIAlr( zm-3#4_@Ndc_GD#3cP9W^))hB!tLnxV23o#1Rk)T(H^1sUrde=_w=ejb-;ErjQ){nO ze}nu-BCn1^Jkd!iTl*5mk+18Yvur+9s1}sESoZezx`r5S4xE0rbG5GMS)~v3y-ddq zB2S&IZvi@Syyc09)9=nYE@nKD6{S%0o=crj1r?>K=#h8a;f@X(#zB%$|QOK+9L(C7eZ)!`Oeo^ zeJZ~l@`V@bik#KTDE2;9LR@cZY041v8b?bQhqF{37}==VSe&|l`yG>w9{+v~J;1;# zuenLRpm_3DPpo*mErgg8IPcg?WE?YUBz$t^PV!Ol-P3k8Wtx{7Y~>M)A#L}KUTRtR z2%VSdb6DksU<)0PT2uZdtc0-{iPrn9a1Wwy)8V`OC0486Xqj##^39B zi@QBS?#ZHZ5B&FklJk5hew;;=I{x6cjKJ~hRGB$biLf{$l-4UU@vY7{+07}nNX9$F%5T~2GFPlq-uG)?&(0HcKj4gM zr(N92TgdDkn}e?L_a}s!N9_A!Dw?8i2-|25b4*mv- zBVa^dL?AXZc}5_wU5Om@_BY7$JMXvN^uh`GiAYdTk}MNY+B!fw9l4ltq)h@iARk4LU;GB$X=wZnD!=6LWfvF~ zIiH+KZLLE6pggA8-dx+;L9am)2^435O>|BF17T=RdVlfz!Czsr`d!qp7YGTvw1PsvKK@C>Uqvo8g{(*K|cYus0Uh0N2A;n=Rp4MC2&0dL2 z1zUoQrY=vUtj}nQtY_z*mSQV15mF^In#bmZ+n@B zOHdCuz%42E8NpOWt}tmz@(I@pJ$-NFBXizdkEpVH2TjE0tIH*vqs0rgYv9E8E2vIg zHOu67;pCmF!^?FWc?$+jJG&B?Wa;E2kopW0drZ^4IiX(HI`Y`DuK@wz;*Gcr#PpABpKF#ej+^h}prdgo&Of!wFzorj!VQ6W38?adkVw4WzyxdpPj zK%L|)13jf5TFON{yl}zsWr_sSEq3vqiKGlaslJi&Yzaz|MN_lJnPjnJ3ufJ}vM0~w zF+Sa|mgVCdhlA9M@4ou7eoJ3cU%FL~VdbwOJHdE*u700U-4KW{KoUCQ+!;)9n?Q<@ z?{s$!-d%VYVYt&fwPowLb3n&vY?4c<$*O#}bkcQCb`;{Zml_VGgJ7R!60dfR1n#N& zD#o{}cn@D%`hIRj;`qTF`iz&55GY#zay5xF9wVgVOz>=1KvpMLr$p%OhLi%cjzM zHmyvhqQWg>w;KymlRih&IX?1yC{&xQv*PCIC@(#1(vUX2WjePxb%%SRui}Zeq!jr4hqqSn*NJadmEasP+@>x0qpx5Bp{FS#?0?^-!D2xoUc25^zi6$x_IF$XHox4 z5AIffGPuy;7Co5W`(7C#2BZ=&0;G406Crx?Mn|pu+W~w)%co!kP!Dmt(xVp;YqBYS zCfLa+_@f-=g8fpfL#KUWH-=!9nZBUjy5d@f_H7RP@_Z8UD$)_32v3DX2H06ZKCUin z9!vJOYm?1zW2ed5UF|w~^R0$>*4O5Yr{q?4-GcB_Zs| zL}EJiqA9_Q6iK)x+}HAY;pNK93Yfo~u=zU^uF?x>6@QVl&FJr@VP)M?@CXIqzA5@* zM5{2~feSzt%k>#v%&X7r1Z*>2o?98n!+$^Y!ewv!MbSj5xZJ)^{O&{W#Qs%5q&9L& zSAon9Z1z4Vp}O@lR>aH~9Z=<8`fM?}AJnr*C+J5-=MDp z6$?l~M8GVrT3`ybA=8XL zNoEC@OeM;&U`p}k4#RJdYXp><91D;-3mrQmCouvD=e4oKVipQhxvO6h@67Htre3kU zVAZVgLxMw!B7{dZXUSy1VvZDD@?3nS_E?8Hf#O9t5TU(b05#d4$?5nNsX^E&^;0%~ z*;1do@l0^~!ztmb%m-F+t?L2|Rxg-FA-5sF0K?-0J{H?U{qUG}L=9PnDWd2Vsc9+E zT2N8s6tHhRiO-zLGRN{{@9Q%Pt!MaP1$~WDv0W?Pzd?ED-#4pooXy(*>M zPRN$updY+?TXNB;`a*{Q1$J)%f#l{MinF zjKd#$_C$<5a5)z6#w(d znrRwB0*jl9367kPrV>JvIb7 z0?X60E1VOO@0$wmtG-+>Fi*^R@x}Z0DY`RU96;h}@^`_$o63%$C4_fu6$`k)E%vE9 z6!*7-?^amnC9;fl;*%W_Z_2t5{WQQl8dkCEJ+Yy0~<9MP?Lx11JqVC>st4oDn|Q6GTXpor#QiZHyS%z=X&~z)Cg8r=3(iwtfoFc?eF=>^X%F!L7(Oo5DnrTo!U0q-bw#MY{d}(?}f=;bN~9Z zq%zHo+=vSM>FybUBiBN|7Pp4_#`JU{kqzh}(p`N=?(TD_c~X3H zcDH>IUXvYL)9{&v8T(wIf4Ie0qQ&eVFT?krFGK9#xs3hU25UW0nV#Z=A{KHJj&+rm zqBq5BH?bGJ%9~SJ-CCL5oHg+M0b4n^BO5C1i0Xt;MI;t{#ifauG=wEC>dG1{ipge% zt>8}OHLi3Xqo?{Y(J$}G?y;Pwtx8YN%Xbfa<=a49a9S|R>Snvnqq&yGaLw=*gB=LO z1PXG65#zA^-K!bxT#q`T+;w}6j`=*@63Q9)g)n-#@qx?6Kuey>E&;hCBUM#D-kVn 
z+3WIV1R^rtNj1J_i{~+s;HKv>eBPgBCF;o@Ws2L2ZD$$RWj7#5&qP^vbNhbqwNP08 zYk$$CvZ`&(v2R&U;;Q0+qhkU8;iZknWQ6IMsDoUnn|s87BSgAlXVoZ|em4kSC70d=h9n}77UU1VAwUFGqy!0{%~$I~{k zo8iwq)3o>`Nz_C5=v)eDf0c;>CU`&8<#j>`>_eRhBZaBHBjPVUI20}K3wh2Zm1p&r zru_1UV2i=*RT3Y^>jK5#Q`ZUUQzrZHh*_r#uc(9^Teq|t3zM@?pZxSYUI4BdNEHEn zEz53BVFVA*FF>ZOpg7^Ch95Jo$d{$~P{%HgGCwGLED(OHMWt3*>e0t9ZM6oI@^$RI zmg9HlS5_?=Qx?>-VmHp}r8q{@NO9Z)k*cnrr^D@_NNi-^AoTZ}d)<5d@M2sS%;t34 z2f~+^B<%~~Q51uXhUJy$Z86%IXx$jHJmC^h~n}5a}K9El6?omH|GpxJA&dV_X@8eP?SNWo= z=Hm72qLK{Ro@$IICh;e#_Q}bU@@X}FoM(z=`=oT-vBBN&J3o*IgZs5#*ucwCD_RFo zQ^4Gh#^78a-9;Qz2yS;NwEN3BgV)#g`Bk2jBwelSId8LJWXi%A$&uSZIYkEJAoF`2 z%81JZ&*h!nXZ$_066RM;%8L;|o>`h9-O&$Lh#HBM=KFgvdBGv6E_IpCr7q*xV%jZd z<;#jEr|A2kmo7q<^h#3PP|jkq%viTz>Bn)^RbhC)CBbv5Wf0dA<(Yq78!#nueUANu z19Fas@&(7EDIfwW2~wNi9hd}W80#>gUlAu1R5kYsJa$ho$y=~{dR4V%2y?0W)`y&V zsltA>@781o%goMUY>9B zG;dsdD#XuMLzlzH>9_r+gfmxH9XTj|iafCvrHHpCGByJAAbj5doSD0DLfrj^N2Y;I zPb6!M&ZK{$yTHW6{t?79o!5Y$nGvztIUmoj(KOsJkyf*78Dcx32fmE(Kgrh>LWE9- z?s0GJpk9#kO>FX-N7I8ySFURQ(sjy*C0IFT37OhhOlqh<$c%^9tak4utY~K9@7Gqt z+I5Iwt^Q&SWq8@uhC#n-mc|_QH}tVa&Ot)jCWGjz^Cg@`?dsG|s~oyoF5e>dlB6&b zI;N#T?o~cNA@0-uQ&QVb_utFrNzsWgEGAR>Sr0K~nLcJ8%ezo2-Tn6eT<3E#lw6iU~PzFW+i$68=y@db<4S zedRV^w>xuJt(aRYX(NhYJHSwC50=b{V?|~B1lL`3g@U=4z8RRnLZqN~xESr1c!v%& zHjH(>uDl^60{#}R2r$~EA=P}t^C!2E>KQAtb6z>;>eKym?GTS0bEmW7vSDI$Z|~yz!*!6iGGyq}r}w{tjOiq~rF`ByC(QWoqH1id8*`NF zUVa*Ke{57Xyu!kp-V~xd-TJ`w!855Dr~1!k{r@s=r8h3Sr3S#9jTumw)ur;ry2he~R)= zq#z+PUWfN(buxBYU|4@DH#HT;3g2;Ki)a|gb&wH#ZpCEq&yci#!kYc(NBT=qQ;tVK zRWv1bIub*T{#>K6td{!fH)x<2coX1E92bthBJp1=4$=NO|G!jB``>lFIAnWY^Uk3b zp(lBzaVL&%bpY)RumG4e3p5Dd09MrgJl*fvf1Tq1S-}6&dHVl}|3@p-{;kqQgK8ve zNd^NuqW(YJ5waVZi>KERU^-o)5bACaP9ez zlD}nMGXNREu?25jUd!?2uK+urV*2gYkA#Nwa&$8KoVRyF`^R|h)`Ldr6_<`Ro!M88 z$f;G(Arw4MQJ6x&+{5sxheolZ5araTyr1k2Xn(m4g}i?sXk7sMnIgNnmkh4A+&&;= zI-)ocSBWDD`R}gr4f?R<^Y&wQe`PP0r#%&=bOenVxcuw}B!Jl46@B($tL!?4-gpr~ z=y2|i$I`@QAyx@dEuPCX(F5yNOfM#&7m#B|-qCemK=<_f zn5!551rPimfmr{%|G)nl$%;UUjM=#Ki~g`%ADx{|Kueo4fC3sD0F4KswaNZOP8;G( zBi=;eZhCxC!3@ST$hzYW5QzpesB?C|3t+*IXm+~6}KLMt?Ix}@m* z?)v!=hIcJAiEhy@%?x9ro3A)q{t1Zv10&MP%bA`SC%i^|TEAkPRl7+mfx5a(vo@XS z_OFaMhfMD(k?01XiBP1Q8JvfA%8B6=gyTP0zcRr3m2mzZ|3BB(wHu!_*DY%ulWvd` zgCb%4Jq-~sQ@Vka9p0y00ax`e2K5=M9rnU8JN$$_J8~^SyRVod^nzrn&1~rW?S@M{ zHNTcr-R9_`0ZMKrA)p3>gQf_pXTAiIX<2;pl-%vQg5xe-;XE8J+1HRFW);N;(w&DlN*sn1U5u29Z73(i+itZw0_N+n)E;eg&A>@=$!iz9gq$IEgZUr|vont-P^#aV z1Fd88lug{7g4f#|YONmnfAUKdcoJnM8{IDUx+g*r?T9`BRD<2Uh=TrJWPxBSH1@h# zG?=vt>QcKh7^+<}arw0m<$x|9e|pba!%ASlO7xJc;he`_yZ5KYD{Z~?ltCmNL_CP5 zYq9Re8_Ya-hnu1ftVF5nAU-KLeB47bM@|59fwwpPJgMf?gM97w-g`s_xD0N;y=rWHp@fBVBjJ6^RXh^8 z%FSp91LV60hDmZS!yro~o#fiIn@jZkuOBu4WWIm*9*TUbuWvRt`E=XKfc4GTMdp@R zzC8{812(GPpuSrKpm}VkJVZV-zBAtA3Sp;o;mqPmCU~?4g@dlaOJ-npm7K4L4E3IG zzVhSJzP6x3FO!W8jQUDL%=yE50b{U3%-pWjfS|u*)8*azo}j}HC^pYk^ePcHRjsKM zoDX|RDT{XWT-zSKf@xL<$*h?rfw?cs!EXNs#du(6{RhdTyq@ysG^0tkdgiOY9e?9= z$uA1Cv`tM}f-w>hHfS7c=KbvBne?xf8X-@GI4`YLj~%%FCH~}W3Z; zgKZ{7KM;62P$CmMntJ6{;@$XM-?RCY$X)}+VA~_jM9_BW-}o9~;D2A>OPlc^db{Ju z01|Kdks5S75%hcb|J)?_=UDqMDx&|Y$JYNpulpaU3jfZa3uvz+F}*uUSdA5z#JCiQ z$fl)bZfRrc?AO>|>+|r^BrbvT{n#lvSjYAL~w^JR|oY;Z2^3R8{C&MC{isw$ta4auRh-nBhM zkIC=Ji@W8knNKypJ`lbaRjK`@KZ@#Tjq3p2X*P4boYQ5-r-R-54Kkh7WXFU~eQgK{ zOz=A_F%5{8ya?a>@I;Yr$>Ai<-{heQ9dD@h>%HVe-a;lNDhBP)R3Dms2n^^4Cq{aMrMVOnwcP5bbf1(76t zW$e7B#>%CoG3+5d7CW3ayKr>9bJyMD^}5_AJ7RTuvgJ5RwS}OzSfAFZrm5rg9vttp zX`SRcbT;dwi6du^_mk${p?X{SGpp(ONp&>a>N2ZnehR=yJ`HDD*=L7e9IL2L+UE=k zt~g_5WtG>opTEbho^&-OAuRo)BF|}_cgCVx;C8c_yycu;EtfmkchTI11%y099l43L 
z3TtvI`TSCFzWv7Sf|}^qFJHcZk|V-I^|^4zX;6=}7(m22ynN}kr~!^GZA#f^0_z~1&G^tKWHqjw> z#t+nT$6&6oROI;7869aIsJc725#7GcVt6<0>UWYUCgrSN)-r0+FGXtJjeDPGao&4f z-aKTF@ouV)NVGQC^On5@M2Q4#CK!j}n)Sw@4e$$2CAM46p_f`01g^SdnG3$*kng#j zQWwkn7;uydlEiNSl^D2-TUZDypM_!Zv>CWkxtc#~#^&|t`>fV*W4?~$FT550eZiB5 z=M-6=P~kl!3<5UtKUMWF#x6D(#{WdlH1=4_7r!B(5R3~pZqM94E4iJrb@T0iCYw3m z#Y%7pkR?vP@VM_i7-s11i}go8E^js}7z0qsP9-tX#ut3adE}QHRaAk$K@|&z0xB!W z)w+FW`wf6gH z(~OQlPKA#k0UWxkjh%T54}Mn2L^oX&J?gIp84P{=@LR_ZzOh*8PH8~We+9U8RQmAf zm8w~rWy5H=e<@sJRQWHq{=t%Q_v=1=%D2)iwX&;RyF0hwO!#u)e9Z-mQ@+XH>A%-m&?jjcmk?At zW`>e}J`lC~5n#zP8Xw`tY%yF{d4_|^G! z5uCk)br(;axuHYz7Kh$dTKa-I-HCH=yZ8ttJ9@b*m&w>qL!vo%<^gBi*H7`guGtdT zuRJ(tpaT#FCC<45uJ)Ma38}ZYzwPM0f1jLTyaZCu?ZGZ4YAUVBod=%@(zC3cRBn#y z&}bHj-}9I&gD%2OFCV7;9i7hmC5VZFX#+fs9I zG+S?zz@fJM*;Xv$ykgnSLID9{Jq7{%ixXft2oPg{lLnLBCSVYC*+Z+IDQP>5-c0pf z51*Hl)Uo`ap!eRLPt?;r3?TMrY*UFcU&Gnq_E^=5;YW!SANR zs@>Ay&)ezwMj}x~Wv58ED#|<%YmO@1Ek^5v*t$^iqje0zLWWk~k7|R9XIc<&t=<=} z6avhp4Moy{9bT^QZ4V+OEe$0Kv)AAzBmIr)?KsFLhU!jF?k;DIM?CZwW&S`#Us{zd zK}Jy+RO%5fg_{fKCY80M9cee>m3J@f)rAJwdb+$82=AuX$W{a#UQdhBY@=C&nmfWT zHGXc!rq(-E>Nq{Y=Wtlfb^AP^-r@#BVbWjyeMz(Q886`-m8T##P6b6izja5)HD~cv08tBdW|0 z$djpN9-!B%sJ(XMv>1C7!zq5yzAI3S%B-@@EaD?@Ic+8|Dzhb`22&62Ml)-UKc2Am zOiYnEJE=cg7W&4cbEc&Pv){9yiUfwzu1v`361+327E4+Xs<&xHz?x$`X`7xD>G2z- zzNEe?c){66TdMSHxGA|(84Z~5H7+v26>!bAqb)9ljR_T=vX`4p{A-%dUbyS-9)CvK zrzAu%dN|Ut4FvvlWcdYA@3PB)fj+$%1Z#sa`%TrS#UsziQ@V?EAn18v5ElL`ar&4YvkV-muJeV5yX-v5GG1jKP}xJBSBEj3l=a2rppu2~omz z@h?UzXlej6&qDnY{ZkG4Pd)hDp9(*07G-`Rw)zUNi8xN-IvFaS51&JTe`>*C!~U)) zx*+?JwTd0ts-vRjae0f2e9kXD&r-pj0PWt>(p(?!eFn}FE-T1x(@>1zUmf$b)GD{I z6)<&tiRrceOU!{28M~b=kGRaM#U3-hkE)6)D3hf+HU;`lJ2+e>Mj;J=`KN-}NO#^d z4KlSiKG83pjX#&%mL~STt&d*@rJIHtG~5;X>y;5u*R;rdYV5mG*&`_5;t*XqRTQ6T zmzykOfoA--7p!xI7dAGvg}~jg8JZq@I&ukx2hKI}ct_$p$%CaUr`AlcB9JfMxUtVa z>z3gcrOzt+_8TjRWEd+%kdW7bV#DQEk&)$unY9Y#c*u|@+U+&Jm<&p4UbORiw4!2a z2-X`rxqB8sQIUyI#AL2yfjlAE-S=>}xpTWx7hEUu+un=G@#EC<@n+{1bYSBTG z2%=af3g==n1PH0bKq8F7E#D|#yGqX40Z<}`Vp*gL%Hw=b%{k8wKVl1xrx!6jdV|mx zRuI&Ab#kwTZw1NFjLNFCdfO*vBKhmfyM)Aqw-1odK5iKV)Ru$bKLPI5IU5hvWn`2l ztV9aN`C~zLG|woDRr1i*Xpo=%@Ik``kwEjJak=FQ^iuIJb97=uA8W!=$>(9w%&fC5 zI~nR1R>CLp@^5g4Fo*EOpq;$t`5coBd2TJbhP%x_S? zKIAlFiXI?Kj)eB_!R3fDK?8un&|8uL`7K4%7$?8o@%7>G4&7_kya(>y=}9@U+%jT? 
zpfZvBk$r8w?+CR^3M6rGon;gR0xHq3IWhF6gk{t_iLsVLKHg(y8#uleZC8~LPSuAon>pKt6n=-J&ytKf z+HU6lNR-&G&-?T)xo?BPFsE_w+oD`;@evsdKTvV~ zOlu#Sj<&!Cb~P`iEb?f2K)cL@ZD+`_3wwYwLd6FUvxU@!1r?{;!2+?vv0gc{T@hja zu)KU;V7#`mWSOYx1|*j0YLXUCsXhPV_`7jW<=G^3?W2o!MkdCbs*kSA(mm1a5%WHX z{?xcB1eft9io*@D(D=*qhIk)By)S^wHNXjtN);)A8q%3$`~*S6kQ%pj#U zfsY`N@a6aLYglL!;^9ylM1)|zd<9OsayoZUfzvG+5yrKXR+>TsPrK;!i48$ammF5`x7 zV$l(M51;qG8+xY`=Ii=u{k?GD=0$?3@_^jOntAn)*x?YojM%w4P!_&>3gmioFN#U(tnYf z`aPW%6EOd(BHTqV!7pPciF!m!3j3^b9nmstHF#F(FgSE|_SMl`cpViEFt>u880GXT z5UGBkf0??Olpy-A|1=|>;Q@>oF;xb*oSnWEKAPM#b_(o$#;9dB?|T|>MGPlP+7qkDq2Vl^!93sKnG z;-K-QCK$a9!H;c@=CiTEX({dGqLw>EU4&;aAQhLf$^FX`?=>B8B5asR9I6X035O$_ z*D969R-UoFhAzEqnDC(t{Yr~3?)tn}UA4urRjaC`)&Cr1)GK+;=&|sQna_1tZOrT9 zwB076p^g)^O%I}aRfa#t?x4GMmWD=l=U$qRv1abNEJeWe0uJ1pJGs+M)M>{a2n=3Z z3d+>a{WfzT6MEzznR|L~0V>O_saR(J=Hbj6?zkO>8Bv0Azyh+r(pJ;ne{Jd?*b6_o^G_|i()NjP~E>l40L(HB_~A9A<5k* zbr9#lfo9L0-8rS?hyImjOR$c8!tF(dmMd5^(TwOqVJ8TK>rbFAhmHLQF8zxVbDlk_ z`7aCDfC0(Ui_h{|??`8H`nxGLdRjw?3YhdU_@E%W?)lq@&ylJ0I+N8;ttO;-y_+hDuly}j zm=6Df;D87`3Mp>f*~Eo!x7l+|c=E5?@voRPjz6f+FhfnGX(&9`prujkx%W|K6C49O z?GY{z#J$lzligfOs9Cw2jaj*c9+wPuYRcbFxK`#+eNOPyb*izo1P7{v;~k~nAPMoe zzd@6gtw3FpV;)Fhs4AJdN9kJ-`drtT%2!o&`GThqr;N$@d1GD|(d={Vch62yeYZ5j z(%|poCff#~wpb~`vtJvX#60&Yzd;EP6{)qF+yiZjuWg)yM0*RTz^$b0j(nL3gfl<} zUBnKdPah#;5VO$P#c3ZqWBgDpIp-_!`o4;0m7=Z3Vj9Xf88;FERS)M_c=U^BXPRi; zyX@nEG{e!ylOlNrR(n^Ure}ivKA5x%UAzPfqV)VSj5|u3Ml5%u;~Y~wD*Zx3Up7$1 ze*OegTx~25{<^d=DpYNOqNfb0AnBI~(JqK`$@pD(Sf%~eLAQ88p|({;uLFITR+ ztRR{Q+VBL+J~TDj>4ZW}lEjsCvM~`l0`j^hXJUS4zmwB7E4e&h7l4CT$Ry;SJrlP2 z?vNxi(K@+|h=J=brBS&qaN%4MByr&tC089kjeFi1$lXk;rE{7i%$&ZJ?Vx=s!^f4T z^`B~E0sNEC4OqnPEHIg05JiYu;LbGy7-rQqrP&LLHC<*bPhSQJE(;JV05@RlZ4XP+ zdk~f%v0Il)jWI9exfh=tUcXS!Z`v9VdtYf(lIth57dfrBoVH69_KpzOr3iBk7ayK< zdRJX^U9RdDm27pb;-`le;;{y6V!IErhtdHH>*l5UUVM0ArF*{E1YF0l)Mks`sui7z z`E|!5BYS5=_lw9}AeTAo4aE)W2u&J17TY-oXl~K|tZ?bdjf?ivy3JQtUEy|+7yQ8! 
z-g7S7J59FnRb+zb>YQ(c(N6`w@U9)r3-6&_Y%?-FyUW!y0>44Ka<=v9jT8m7>kIeH z?B*+p^qwXgPQb*@?E1^6@rt*tByQ>sGAG3vb6#o}(gAvM(Fd1HNM~!8 zlQX8Lf(m;U6mY7set12)gwjXNJKDOZb*p>g?)OqLffBwP~%Rrz17U!YHWAWPQ2pR^Nq~9 zX?Id(hEXDbuVq@7SxiV=`|Nm*F?N?+WTB(HFV7Qo`r31=z{0mNQy|*Yy^j_Mp-Ud3 z%Z-tK3UM=nwYF|3aBXZ4eUnpgzSQ0GFBzpTIMA!z)GDPpgyDm+=q~WANykiDk68kO zt406KwC$FvN_H+=tU_#_hYP2;k#S-!cW?YhbDzd=AvQ^jLVO>{qV*9`GTB(3imo@CK|@4uTKF&hE$73L*Jgfj+NcB8m( zyZOacc0oxU$GtaICj+bJOiN9lODQmQ1d4@Tj8agDeVlPan?8p!3_UmF!b#HjW#r!gQt+d6Bt9njS>$Z%{NGFhK_L8h^d}0KtFgRIhBn&2}*p zH<6rx{VV(u0njTw(>(RJ$#0-QK&PS%Q-nwBV7!BYCG)P6bXy&i4*kTf8C{S^@%UKz zgFtHaQjsXC0=5FFDouhV_7^k;MLXKU#@$#J6z@iHBL-AJXA~rqu8klSugmLZ&lCzu z(Y$hQW&Ky!iYp3y1SJF9w-kx=^4_iCYlRyzL`CE}VEZP7U0ed=g=^)jFo!~m(_-gG z#81-kM_KJ;&f{&E*)45a&VT=#yy_)k^*rz-wa75}M;uX=ac%|`qbu{G+*V{HIYW{U{>*Wwn{*o`0tt9G;t#kfgrFBuCfuwV>IH}fCDFR@cv}#c2Rhe zy~QEYR(-4fE>(oOrT4$Y{`_}r5Y;l`48j=^f#$l5Bp9`XExm@CkF(AnUr*s2dCYrD zwCl?VNR6s*B{gd+oEwXb)nt2PO>k-dW!I2We6MoKqr$=Jt@1LEN^{>~9d^Gw2s%;-44&*Ij}Ql1z#F%#CRG5(3hSvM`2O zTwfr+mV?iK`-=XYoMrc37dOebZKx-9{z=(05Flhb_5-9^Fq0nDv z3Ro$Bm_Yyb>iu83KCDuJa7yT1U;+;QMQiuokfJwJLjDNsAn9EY;_uym{s#f|?*r_A z`tKi70x3FoeuLV84!rfx^h)64Nh1hsl&4ZO=zstE-@M=YY#;N^ZrgfoRXyPKgR?{4 zzTW3MnJ*bIX=WI$O23aVcF?BD)aLo|p_RvG7;ao_y^Rt=H;yaMu5!%ITlZ(SE4vgL ziajj9u%mOn{$J&v|Mcbg_m|-Rz2?%bko>&2ZjB+|Y6Dqfkq?-2cvMw6WN!jOclO3& z>0j<%y=ihu1jHc8MHg53BV2OC!yiE1ZmIlgco~xakjElZ414`$&yUsQnvY+$7z$Xd zr7C^ccnJpEUB{5q4U@9YIIdUJTVa>w9o(lXrb;3-wr$b z?BlI$HK8lVfb-n{S66IY9^xe!cJk!fBQbV9y~dv>Hn0rVlcKqi zI5gt#A!9!(=ikbR7c#toN|0!5YDjmEbuAJ|UrA@Mu%JH|DUzD?<{uhT{Ud{{|C!ea z-uexCh6j?(W9SiM^W8L3T)i3b!EaE3`3&gqHADY-_5SyLI{uL$QuPX=Oc{f*hM(TH zc6sV=V|S+uB|RbPR^ew#{dI9mfpL|OS?8H_2CM2tx#)<1S7Bh-Yj6#w^8;lHSvM5# z|6{nGC@?S4Go6={X2#zC{!D(rmycamtXp+dM@WEs=d#?`Ycg-CgT*}JS4dD6sX)AS zx2SL|vnjk!?`nMCQjFR>b;MxD)Ak7P?@$X;U5UwwuObO>&D|n6T7DY#L{0Mdn2wiP z_zF+L-wED-1%8xDTwd&OCajgfuYU^ipq#6Gj=%VR{>0bpLDS1d8#-ZE3&{;>?+B%X zWcXh|kBt+!7A{p{c@o1~ouy^i$99WwlYSzcov^Vq#7{E8p*t#sWx6NeAi31HBVFY* zVP(t-tDk&qBEOkG33J~-&>`COAarouagu5QLBSM9S5X?2x+kN@%pZcr2jzNGn}&4U zPVchX>#x0N_M%Eu6Sf23iX|(EW{4&LShT5h7mYPBRqq(ZThz7IH>AoW?OKyJ((lq| zrFUHLx$;;F%~1XoZ_TTVd8dNIa~1tu50Il(RC?Hzc|!Z{Rx=nTL6 z&iATB_7lCNa4pUR{}OOpiR*+2Pqy3=egk=f=}z7Zk;^J;8n1l${mRt_3&Dc>g4v9` zF@0h*2m3qGC0547d*Nfbb=kQ3T~t+xj^LV3`uP|0%JCIrGIg$2tPJKuI`(Fm*~L^I z1ol=H{~Sqnq)yAr<~v-;Nf^l!FOe%%@G(~-Am7AGCSCHiOc``OKnNf{jbnZBr37zW zURNN6zPfh2MGiBs!tSk-fDsTaSv;Yq`D!(wUK3^TWSXy!cl}cpdpLkVqixZ^RGlTT z?D|)jbWwOwc3O3<5P%1mQ}gWWpqrqli9uv|U|0l4VTN{tZa74PNhl}-u}AN8v)|lm{`Q#0!R6rSl2H^ zVw%|)Y_Fe7>YC)2O?cbnGR1pCu&)tPy)k?F0Nj(8R0qJ+T;M{ZBz)tDzs4Zrthtkh z;gr$K*}#ukqR^W%fLvZaAJ+Cah|#}_;n73`F6ge?*_i=WP?UyPL?v*9J@X{={+XZx z`k+osRp&l6TnmXQNvLlQ^>Vwlt_=pYAaFGc53$O0LrY^z?22jO#LBj1wchtacw*k;F~_q*L?&6&p3p zE7+ADl&Bf2soc-6A(wxyMmr9;2VaO%tQ=UET#T2_&`mBq(XQOUAiK~>U{=yu`dGBb zV$5e~GDRRS$t3i3Xy-BW@YSt0J>?0--RB%reP(jL2c=gSwIOSSN!)TT=9HAo(!xH;e5@ZyntzuB>oHGmQx`iHFQIm6DRyK5Z|g* zPd%Jc+H`nde!;4ruaXA{P>?qnE%NCg4N(|mRwaFVSQ6TwfGKD(1Ui>8nJhPz~+G)IX3?eSIdn}!kC_xf&IuG zFSpvV_%TH_!(qkM2g-7+OK%p)fvSWLB(c^K9RTx`bCZPzlRI91c>>O!_F9^T(xtk= zv*$b6?`yoDoKRh4rJPmmmD9jDb#{8;So}(=V=VRuazk0Z4t}1xc&J+qy>cg+3ia}oZ1DMjbcjAcn{X~* z9c@QqZw^Hw&(tGWen6N;;X37yghiVZzjx9Bl(u|{SK68OADAf@c3{!yENRAV&V6odfD$Y5aHw9hG>^l z-exkJwjRMx%&1i=)Vm+&>)#Z%Fg#uSW$j}DReppyY?}VY$4q<;x$LqVV9Y!l2D=e; zr9b+-{-w`u>^7yAuHQcja+K#sML^yXiiy_&(Neu-B%Aw!RN77CO`JoIUn_V2*xc6S z^&f(7ZDd$)USu_)W!7EIpc*1vA>M_n5e6|DfX3z2^wt75&NJ`yr@X5;mh{nmQgr0H ztkdgzI-+WilUI{{0Uefx1k}=!A zYP@l6%k0=L&ey&O$$)l83x^M*nNf;x!(q70Tm7Y+&r$1k!UxMGpK4lk(*zRk?lELb 
zHPOgV7qB_vcy@=dHk|+)6dJKGX+73rww$y!I%_hz1DOmCy}sf@?+$;@Yi8dpx~qwI0!Gd^-mDycbN`<*uLbGf&JzV}S_B!Q6CtMySr1Kvt5*ET zSNM#bC{{^#FLFyOaxmrFxNGc$($=%-`Wm?J?3J@#HQrC!AS0oxWT+e4ujMz0O-0Ke z-5_rSJ3R0$=^|FX3CGmsrnH$4uFp7F^BB{7rxt7DP%9<*odZ&jTl)P-BN5x;ljEC|1whYLe(Gj=}bHCr~y7U>$H+&szab+MLbXhAG<>^r^92T~ZpqZrQh z2~fayg8v$dUaI4H8_LNhmhQ#=taR<-)`q(bm$VN3DRz4pwh{lGV1~==xfrC^#?3!b zP-B%heOc7u?2X85b|ss>E1BewsJrjW?3N6us%}6?6>m%Mbt#?|uQOXVRfNflF*4Uqi`A&mTLkmH zOuOBc*;l}tw-nb?bsCN!AUtrQZ8*E`V!J`*(~Z?1?s!GwRm_EYlt$M-KNz4VL5-VKJbK;aN&E1*R}6VHLMK15o1FZz&kBvsbZsNnh$(Wo=P;yy=Bt z55#4M_}9%*3B0?~MtlGvsUs~3LQfc9GGU8fWZOs07CxeCbGaj!Zp=3;=2PgC`t&~V z6?Hc%9AG&_J#m}h1|*RqOKKZZ8N6oiSFTB_X2(?vOt z#V07uE(RrC8?oNjZ_H4NoA+%sW+JcMF7hdgVhIE85R>KdB1tBVV1u3o%9XgVv|YfX z*(NO$4-ebRt3~EGt9qVzHz$&|$jG7Hmk+Qs?ICQPzd>|C)cYMg{BqR|MAP}H3=GqY z=;pcer2NOl)9-kD4mw(7R@gca2}o+h{V=6vSXbHZa+H_AtlXEHrr?4kqb3FKx>{N_ z9vU@G7nWDw>4#welfx7J6m}+@94Mu=rnvmB%vbJ4BAqY(1-x^b;zub?Z?^qmlzx%8K5iLrX^14@IM0lc4Mw zHWDY%EU6<=6mlM`VN&_QaLr9L_ZDYHWDiN7IvEd`;fDMMb$QLYEL%pv-OJ#Rkt44& zgbkbT5$Z+u{G+D7Hc1bY;uH-EO26$EH$2mO5J@uyIRg(SIN>_mZxasNLJBf()CMO0 zl)pPW!DJg>`l8u-jM#FEYwZy=~uM@Xezi3)^dgCX}H3aLX5z2PW-+^(?{Ikt5N9| zOVtaS&c5DR$r#Vcep<2HfCb|{54~NR5mq{)EFaLBasLlu9M%&A&uf>T&K|~!{5%!K0t)T^Use9WmvJTd zw*+6Tq4lVy==+(id~uvbwScMnTUTxXC4DQ%Qsj#RbL8D;-c6B@1!?M5e*7Z+Am?xk zx!}9V@8Eh>1}Rlj3!5q`U*}pvf>O6S=$cSBiEw-f$WV3k*?3FP0JaTUP_{)Mu{2eQN$ccq?Md>- zf0|JAtkA7}fp1>BxN>TVcsKALEeIV~7YUNkf#v;VZ&gbo8U!T@BOg`?!(Fx9O5Yyh z6@S@7uPe~on)m)FaFDC*8ja_0pWt|&dcHTn{ACopJD3k(gSu3mXnDA+NGO*NYO0*N6SJ| zCe{)^fKx~SY_)|6UN*y;Fd$@H^Ef|gGzvEA%2oPSnlS*6h;2-GvA^(jQDM+NW?9Sx zMv$gYv6#Q=aM*TnK!r}T*CtQd!s|9Ch^<%y4+pBRBXSXRJ2nRB81!j_3%QWxk3M_G z6fd}k42}g@ha-XZ_t~%QqKUlEYn!W=q)WvzzgxrvH}-&;2);c)4u;=Z8sgqEDct1W zLuw@-r=0JS>+U~sm)JPx!xu{hUEX_V+-UNBBT68d_yL%fbZ^0TTd4 ztmo*wG%uMzBA~g*HnZJ@2K_`t@FNn!W&29qh3UukkG7NU`Cagq9Wj0>H*cj*Bfz|o zzm)cZ8WLoTaATV;U#iQeBg8C)(XY_AlLbsN7TgDgS7+=~1k2`*xeuahcH=YP1>vw4 zeeEF~K(6f&LuAK_=k_nS-PlC(*tTKmOW`R zUdyl@a%B)qhmh&*4@7i^+Y18jFP0DABwM{#I{D&R#c5c8`-LfzEfPROAht^BPf4m1W4x@U{}g`7CV z^Ov@!nFS<8eipF3sq>oMU#dS&@~dJZL|?;!pB79OvAfN@Ybw5+`a0n26O`?TwW!R< zru4fHJw<$NC3k@uO&o!D0QSe$`!%S>XQX&F!1oUT! 
z%lmD1=R0rVJ~XQG3Cq&kwa@2%ahl+Hy70N(ImHhh zkArpDg;&mHin`ivR9(+K$?SwS?zgwPcZC?nk&NFE^4i&hqV$tY$24_{$%!4ibGv5p zP8rwfH0kbh2A)T`6P&Q0!iNDC(v7OPa9V&^>407M*mxV2M)x@LMG8olJ=Hn2L)DH_D#pH@Ms^hn3rDz^>#g~+=p80?ZEFsx6 zHxst8$L-Mi$&NfTwxIFVPXKZIP8TL6_Ug0jj+Dx8K9OE&l>y{ z8i0N9Hc~uQ+_>n_)z+{~K+u^ph5c$6!b>VS1N4))n zrXlVk+>>x=B`+;@K_hqkan+t^T_R=ZR%0?K$=CgYxywyX`o1>v6CWzD{Vi><9Gst= z=cP%%(rU88*0snx55G!?DVt$|60czvG-N_e&a@AIAcETCONeA7jY9or9{qmIg$NQKWI>xG0aOiSgy;>szAoUvE%HDILc%7Vd{g(X1e+ zSMbJQP>VO9%h{NsplPX)T3COWLIT-|VdIoc#W1?!ihZINg{>d++BPTw!rwcs*s~PA z#^uW;9s9eB1S8adzppPIT|>?coB;x;e@;*!mk00OHW#odRp3L zJ86bD)wPWcrr_(*yEKoJ@e@N00n#A48&n&oVq`{0U;+T>V=Tm^b;A+>?TZ;Cw$`gc zA$$N26WxJ19a7Ac1`(=(fUbz6mtH<3`L9nh3IQQp!pGDbcg4xA>xf4|!G7Q!bj`}Z ztg_rZ0R=fm zgM*6)$>}{jPwe_l6dN_%cN#O)`elwlsKP|<&dSg(?XrRZ`~smITZZCvLy3H^di-YmTfJKa zg|Z2cKsz_Gt1|AK$WzTWiac}QN!0ECG^gFyHdFzE@4lrBQqBT)jZa2_Vs%mQI$W>N zdb4_S)K(VZ9L#L!W6HmmIpFt%FWkAErWAt;{6tPcUD*ofGkZF-u}fQK4_9VSa+9UY z@(?bTi(j%v6$pO0Jh=zZKF+)U1+;s&0hN?F|LbPL!%4BUG zPVjKIohbG6YP{n1=&phJas&#EA5H)W4w4%o-N|DqEoD(yp~Ys2S3J?U?u&9#C{x~A zOjmpkk9V^MsQ=2jgtiCu5u!&bfca$F3UX(#kOT~ZpW4`AzqJM~neh;K(RhjKlG5~w z!gO6@rjBk4tCyTC#j#iTX1-Reqxi|0kaO-x_7)`^4?clV?v3LjaY1K}JRm}fEYc$a z^*yr*SDE3xwB`|Qoop?Z(&6iG|CtDXp>mU7IpgA0ZDTP&T{gp=r!!vyyO``}gJ1`` zt96sgd4m-FwPf7ftjSta3f|Y)?+U zOiu5unUfw0NIoM-@8!D;?{ge4AEv)xlT%@qx8bX*2KH7y)M-Bgi5)~UR6uwTo&bJA zA8%8hhphigNeK0w*l=S1cxrZ3Y?W3VPNolz&n+rD=RS4o!T_i_U0$3H?8!#f#Z2x} zVQVp>1XH+ytRs%~#VLyq?Tyibf$*&iwT?ZeCSzxd4UbzbghS%}_CQS)WUatl zyFg8^0!c#sQE)sfSHlY*#VhTbFi#+Ny@K*VB;dig#+^f)?+qX%`J;r~J_b49t-Nlu zy1hRcO?iLrpknx*PjOzb+KKmA;u4B6d=M?vVnVzUBy$7H*6X3yDa?UWiE+F9El)nj zD^#m-W|>Xh+I~OLwb0qIo=Sk;KuhLmyBqL9JWkQPQW|8^9gu)1dW5a7A8dL0etJ?)fwe=r~D>$QrlN2xb5t8ZwL1CLJn> z8ZwaFI+yZ}hy7+}RT?SZbINGgbd;5(M$aA5s`Uo5etD!Av+oy+{V4dQU@3vcd7IohP zj50V1zkE9En-kEyLL>YWU(88`og^vj&q`8K{DGTUGYTaPhBQQCY3mR(?p@%z=qeRa#}>eUwu^oQZ2>#wH|yc_$4!jA>XV4(20wE~^i>uvyZ{5R^}JRIu({~MhW zvWM);D5C7imSu*@mP9K160$Xo31MUyJ6XdAMIlQn%VaOx*mv2zs3c&bhw7>pFk9#&sF4@qR7O?Xd)mB!X)&U5Tl&FM1Qc98DG3-2AjY zGl4Ooy-X|>giP>G(4sPpX(O(X{qjAB~PqBO<;E$FL{>4Rez9~-@iGLui&HkGed2Jr_A zydhq`XB3r^-C@Zr04{b)F~ya2`iMzz*b*uI9qYvxu4#}Z;fw# zu$A&)7&MDw%j$t)PqJH2qdV1U(zR5_6JRpJA~5Sy=@T|{;9TuTlCodR6kmIqGi^2ugT>x<@Ue$Q*|p->37#f0yjch7B%e z`S*?i4#D4Fc(!qE?%?9APKvP&%_CdpT#8XA*#{26As(^L+FKCsk)!gbyKxTNB<7KT ze;g-6vv!Sy4ALUsa88{e$q2Rm);u$@#z52_Nrl$43;1_>A75=5;%<`n1F{vj52dfe z3HIJ`l9F%5Gsmd@_W{{|^W6d&R=^}f}_vmROstD*# zqVzLX&SmQv)tF4*uMrs$`F0@M2I9vYY2N8Pl;tVb%h4cAw3(^zM}N^Aq@bP=g{`hQ zd<{lq?|OnxwcYV`Dy8@z=#y8wxMuM;PTrv|l{dufr*e&WDzpq`0hoh8{-&3UtxvB9 z3k_cLVUQyDwjHwqY4!*a)5b-$#R`V;(*RNxeHg$P!gTim#jIl>?ZDT|o2F_Eo6C7- z&$zAXRt&UF`E`vk4hU=R6g$Rt68ki5M| z`Ehlrg5KP}g7VI{m>HR-8wRL1_-49i2C78rTZBDL(XwFfX zk)$@9!hoZ6vUmpKLmjF2UF0?rd>H_1S>Ubrwx6sSzbu> z#IHx~O|+tkm!gDUp;xO%M;61+M}Fl4eh@MG2%1lx7E(eppw3c)?op<;qchmpL`;qM z9kpcMpz70Y;LZXZHJbRg{-uN?VgiWj+X|u2o{D##4gSFZWG>NzcUA!w+jJ4M7m71& zgJ@9iQJe`GD9P&S{Z9%Nzi*b?!J0~5$H*Oo__Byt|LlLpb!TQ4xB^Bvjx~|m_9t)m zF|)nBe)SE--bZG6IfwTO>NE5D&l~xllQzH7De*|9?YVY@1Z^kySQPb)A|Aoq@i$16 zhMufTjitmBbIr|U7x4flBq_<>>?4b?l><9Jd@hdBH{pccThh=Vkab-2vJ7Ji%YQ#O ztB3?(D6IroErMB8@zM8;WMVgUv<~*xvW5MF{k^l*6RcQV=@g4DWw)&U>m0FQ=o4VN z)6o4l=w+e{0g;LoMd)}h!6C#ec#|mOh6(o;<@#m&+}H2kh3kkHpDaE6Vzl<~{HZ;z z-qhI-BgE|r!iXqfYtRsOnlcl8D{VbzI2xvreWvs2D^s6tE8{*HhG#v-%RF59ROo>P z<>z~fb|;<>?S)9-r5R1?h>x2T&UTmTm)P2kNU(Z5k4Sy}H%Rsi;aS_`FM^ncGG<$f zb*zqdz=v66lugZqqS~A8iliWG?v6uij=luOKqeC|cZd8ZpvC_^ptV;6Z&yG%I=I{g>E{@-#`{{QEL3i^+2 zy#Fhr8)gU{2Np$pB|9GsP?TcSFsGgUdu71(Zn- zvZB_|br}U?%v2M)z~-J4zg} zGJ=928Q?o~<=V8~e3_^XVH?!F2ur->Or@t2e_fY)Yu5ho4kIJF`Pvfx?0#^oWEe0( 
zt?fxBA06yCW$kt6v}_emtY9z-ZuZ&{QsJL#Om*(58;;t(Jf%+zhcq;vd_~!p0-!h& ze;L`-xsLi`dc@zQOr*>z{kh~Jmu_K(TJtrr0@FA>K@GX#P`Of?(~MettKbu1nMllU zF(o2j&fb;O#@SFHC2g(W~PP^WZLoUu*TtvSPT)$x&HD4 zZdkT>2F{qIAQ8hd$ajV}_1>?`3^O$MGGyEh@-V%EV8%i-E3=n?g{o`ER|u~(fF77# z=l*a?M&7qaYo-la<=KNzYP~W_zu-W-NnSbmnO(f`VzcOB7 zj1g!09R1=1O%cn#O4X<=?Z{VVz_d&i&L;+D57?(}jxPrNWpak6Gvx^wC}OlQ5df(< zPfICjUk<#bYkFfeplk!?z`ysiK5zS0QcUN$p@MVGX%la)d!2b`0R(eldS7E)SE_i} zC@!yo1m^X&PmSF~lvh+RF1Ptq!`08;nY$qKG->k`L$HK$AM;vVZ@vIcKBEiGfT5WX z4k-FFG*hqVY<-o*4W@E$itv214<&~ki^Wouk$vSr=1-q~k@tp@(TzTh;wF@Lr7H;2R-%y{ z(Kc!u|Eh+TX=pF-F(cItkkIZ1g@8AYk8j$ee4TmMaM^K+G%#_5aiQ0<+*=@BYQ$$x zPW5hj0qYvUo|z$2hdUB=3M0EXO7Gp`0TzShbB3{@bS5Fvzvz6%xwh%J zZCuB|fQh~f;vg~VrJ{P$AFtwr?>8A&)c{yy`?PaSZ#+b2mKyeZm7G_QbGPEXL-m}~ zEdHcWJNQD{myKN^hSSWFfW@2>!+|#i9}Vbh-=^-7cZeUpc@h@FNIiCO@$W%-!eTk>Y}W#`~S zd3Jb6aWg_dA3)Gv^C}@HS1S;0l3?&tKQ=h-jDZ|<2wjFk`HZghJU@i zEZDBLX9Vca0{a%90`Ut5Ip6!Z6VS&p*3cv0^l()j__52zA`>MwhAy-~lb;Fz3sXxwDH~V4)isnyq%;)d zdJfW!&r8^J`RqUf!y(&dfpqngk^m*9AINK+FPC7$*2L^&c`{ocPr8Rq)UK9Ht%~#?lEEwa)E+Qtpmq4%DhC$ zoV+Q4)gut+T@M!51{IPr93+cfrkaP|x+i?pNz((nq9jiWmj1p>!|DD>NbA;0PW;bFg z%iI&X7P&T$bIS3JU5EW($m&^apt;s;&b4O$q~#z?zsRw6(+*X#qS4N9Le-+UTp(RC zwM+X=YeGTB8e>=|@@BWV2qi9V^EaOLNmgUeI~B#jM3Z#nOB|=h1xZPrl)`W|HOwCd zJbH!&z0csBhpj5@8aay+2k?O?q$yD;#fRQo?9KgTbZMidEw`w(}hx8OKZyxB?ZEmJSS zC7wUHyIe2!F-iHgcd(qzp6^j|WTj|?PACa`Zq5YF^aiG*?rY%Ad)|9nvGk9mW!q=6 zpZ<4<<08xPV_=A4OA$HnnxqMk4;`y0u~-V*rj?y^Zt2}@4eM)z{^t`GZ;P$3#K?Dd zY&|+pNV$?Pw2z76plG*=Gf)l_>Qdvt+(1ouWn9Mv=FgE_GH_wny~g{4hl%4`0p(`5 z3txSknLvlYzYSyrDftwrCB-^g$DS}Za*Ep9e)@`9s4U(1C%3+eKnk+*;-5=NANwCM zsc1#AkPlB9)%s})zLXFHDzxGklm%XrX5Zx<^dzJCT%T(?LG>rZH(n+Gb#)R~u4shD z<_@KyrxBaddYL9ix!VG#AdJZh$t?>F@v3DumZl?+i^DFmH}D?mzL8~B;g{@#nJ%f^f6)VRwmJz z73Oq$^3nL5+4wZcD8r#Ti$j>j{bA@WxW8!AUs+eX-*DaWXbq>NZmAKYDL=8~!=>I% zF;k-2VRZCOo3AOm_&YTX=?@hzq}b>G-i=|FO>8G!|E@FA*|?ZUKvOTHQx%#`(1?;j zq7th&+lPwDPb|&8%88G)`xaN)Ey}W_L^0!F_w?Zf7zTdiRcACS)%jmqcn zmwDF@`PmwY0&K5_PB8}Vw*U3$s;tiGjROM3=iL$;Hpvdtt`Igu*s#|fjxG18urnD= zo!2=?xHC;23XD(`13j9@sO=hc=!zra@qix3*3}{uq@`)5q?zJ@3Gr6_1KSw1nf|_H zA|YEoK62hHt(M{t(j#JuEZIQ8`aTl(b{)F&7Fi~hOY`one zIB-j^is$2Q3&#tMcIceQJ?a$9lh_1Ar4PwV?zPQdC!22u`<|<G`Uq9)XXPVO6C?H7j}u==+rDm6 z#%YmSP5?wH43Q#73>C2?ui&F-wo6T4H-9Hz+`aW?Yth}~<@keBj-qmpmi$U1mCaL7 zV4%TDIVnJI3J_9HB>6JQsrKWvtH)0Csl1ER9`=sfw`<9D3quu$Zm}GSCQ|$Da9K>KV&dIG;D$ z_vj0~S6I3^h9z0oq;4Q*frSp;S?X|Qt9${PAw2X8?ZX5M~S%~ z>mZ{@!})PLPH2{2EXBWYe1i&gTkn8`P_iLiNQ4u6;gzz4pQh&1m;*G9m*kk;29B+v z8DYnZr!PeQ8lRl&z8v$iDbtnXTV~u~lV-0Pd1uRvlz)cidy?3yge)TPN9fpv=3#s? 
z@F(e)-KKJSj^;-~e8#S+l4JV!-d?$2pZeDX8ljX5@VaQe$Evg5dAyXs+_EPS$w}2W zTTi!q-110+Cezj`Mk3b?&s|}60Ew&|&s#8$|2GqXC%r&p4R(jBR)iF|Psxr~8wMX+ z_9$rG;*ifnr3pG{c0DN*F&@u|VGxx5h7}3>@!~L|*^PjX4MMZ2IU4n}ULn(74B43_ z4_)i$7`#nsDNbiR%YRj*Y2;l2S5t=V(SOP*|8tHBdI3$wdOMo7*Nd~0iAdqShtv|1 zJ)%~hAP4)ZGv}VncK1*tIo>T+h4oLg%WG88sf1s&f0|$h9p73r|3IHR;)+K(leUZq zl`o0TEC7_m+ioj&u3NCBp;2YfO={-sx?lj}nv9U^XD%u0Qj|C~l-!6X{lpN5)7S$M%%o_UuRYS{G9)Esk>ZG&{sCg24_Z<+Uiw3W!{NzKqu-1kcUtYsUhV1 z_A}45&u&Xil^(L_j%Z(H?nZG{})M$X+k$&jFylC{ky z6`q@O<7^%Cx%ERC-yYpm+q|-@BSKpjH(X#7Y>-X&yS?Wav?uNf?T-oO0X6g31Zz?yada|BR-xO%LE+nC<+ zRNLutls&pE`luHk92BgoIiiXDp*iAFz|~ThimT~{a8m(?4S*=eRtIJSD4Ox|wf1r% zgMDEW@*17XjR6KO{L_72xeC2@zjsow!sT@4wN{a7#Pe@z?i7BkOTlNO%kf=@$CN9V z>@TSMHS^}x?EwJ!2d3U+R=^%aAZw%(5N`-z>lDtqI@f*O>M7U3J0p^l+woJ{6RIgm zmzSUsixMGjU2~q?LHWLSE*1`*47#J>>S>ip`|v2S(SRY4jmP(`ItUU)jX8ce0(d(o z3!qQys4a<^chjYX9|*rI4piwDzH`a1yt-fg#}+Axy|(g))aQtFxk4(}v5`hMo*P+4 zI830)1TUG3AXVq;-vi`avC&08(+X#V{uZW_U;DWiIUyP71LOnOc$lP z+7eqSd=*>8Y!bCUiKcS4##z5K+n`#nI5~j;DbtIsLv-R3K;36MBlshgP)|`E2fz?y1)I z-o9)dWyfV@uGT|04@vHf`uD#hPh2oVKnS=7;3?1|VPp83{5TyUqCH30(O{RQql04p zvusm#%dRq^G@|@xhp0M}eg@#`KXTCnGt=?lT{+qXK#(PqFtXz7^on!PdioA-G3R~&`Sg65oghxB}o52 z4pwT|m(R_k^Bc00cIeW0gCZN$4f1x5D%TEIRf= zCR&JJq9_6VgXkP=vJMM@tt{8f$T9X{dd7ct@7Z*HdWJ%W#HyswK412GSdK>b$lN$s zfOfsn%>E`B9DQk`&e3u#%igG^`I{2ot2ZV_?v{_WP1@PHBofFiP=Hlby5^(-h$m>% zT$-tBpJx|;3D{wxQ8y0!?9g7F`O50=uJub968^f3blmGGTnPkm)zVy!ZE4U^2oF#$ zW8$7-P-egcnRajh)J{(lvV*c4uCp;9wB0TbSQ$#iNHX|JUI%IutoMSaEw#j*>`6TA zoZE?|%T8B#w2Y_2215;}kK_WBWJ|j1&)<$`J1?=skj^=}$U#CRfZ1D>!-nx^Q4f!A zuLG($;2w=pGw*V5vjgLO1xxjSh87JK2%MCupRVqncH?cs0*7^&2+EpL8j1D>pQUkD zQ{TZTr5*GuU$EaolC#F2*h$pJ5Ax1C+KHw2YOu10UzJHhHMK#Fz5Y$|XmWdCLT@h^ zCux0B-E2xQgSFWu3EO)&+Rc0I!m=#ei}d9`V?HKVx@G&N_h{B!SWlW+?+y#41_4@l zNApog-ew!wB|XFCM}u6UV@w^X>J&K>~Hdu1c}Ou`id( zPK->cF&+&K)?fOC;tWg_ioT|9NiEgy9-k6ixvTQ5l#~zjQqXN9=fXO)a{lG8YqeoR z9(oa4|2msHcdAjj*l-@OI%7)y>eg_|aC562Q{Ku(#kbSNsV3y0PT}num`fQowX@OH zyT0c3C$%|!H#Xx7H~hNw&yXfxB!GScso#^u3!}}~&}GQ`H_*IKru&5PW+_BD!*B35 zKHj^xyu(cfMSWB1$XuJYlm?Pg!gFZ6HS`?qI^bZsuC>^7dCI5GG$gg!1ir#IZgsB5 z@*-n=rcve^Y(E{V27Ro6FsK4pI}>dU*#~55%0xF^F30iuihf|#`6L#dw>wwU&vG)& zM_vPV1md5Ta(h6AQ4a)SKy-->)BTOkvGIvqIt1F-3GiyTxKOg2MLJe(c(=l%f94A* zC+;7Va(tATe_nB^oI1KdSp_N>HYp0Y4e)#7L54P40QnCOPVDrq+$V+cfuCN7-GL8{ z_pH5{-IByru+MYrZv zd+@OR!7Mo-dy_)sN>vX7w*=EMPf^;^WUJqULl?7{Zk|K zw~q@rxc+gy=7{~L)!3gNK&t(U1Q~mhQI>zA+3%$E|DxEkVE&PR7Ffmi6y~nQi&S3{ zl+BSU(wisHkaDnL3@{7cUn<515^+CITiP6+E7ub-VZeeR)R2Q6if2lzP-+5}D~uKP zVmr=7tSbk&7~u5{38VH+J2vc5LrrqR<;g2XF$=Zt27Y`ahY$`sHqfa#*hE!KQcix? zj>x)`NczCKM@Yy=iHX%*zv68EFXb*>pKaZ|HI_0OL?`pVcySz`p+4t+1y`d(x9Gd! z*TKL!_={3sbS$k;F<%PI{b;wACs|wF1i2V_#CCl@7WU(QA5!DbxK45-|5cxdQp=e? 
zTMww-L?uCLQAZ#So6@Wx6IEyJVf7_zckxe$(n7}fdXJja@_f_w8f)*$tnp zwA$hGSL7I%<4F|GpNPVbQb(tpzbu}k7{$_D2*4j1p492PEPT{G1c@sS%SLB=0u{u%J11Krt#5y^&#hWg;xU1R@O;1Idcg8z+o!_G z>dZ$Y*RJh_xq&q@iTIyE06t{7hJP2hr^ljEqs03+#n!gdAcN%4-)ul zTUPl)a}i14WgV9=a_0a_iLC%k-$N>&NdNc#_m7*BUH&6|HYv}L%M9It_!>OG@!8U+ z1F0)|AFlvEP4x%baH>}tyr^Gk;&(n&gZ>7CO^Ph{ne81AVr^) zl*fGFddnG?eUL9(!25Ykm-$5N4*Cs-IsDXn`a(E6yJP;yoxE_}O7pppZj+@@DCsQq z3XzG5P-_;@crk*jxHW1d;jC`FX|h)msd86Zo9{1#;QbuT@jYsk5ydkR?fy6DJThzr zq1xiyV?N7^~lc zk}Q~Ucz~kgHHc?N=_L}lI#gA=M1e-{K-w$3X1axm-~Cu_wyIbf7Xx?NvmZLX&iRd{v3K7b z!%#L8&EAyz52+6Wh(A={-85*b&zi<@Sbb+hUn1`L6uB?>kX#ES>IIjUgpUvh)?{cW z2bjp!l`4o7A;JM+ongd1IAp@U`J3_}C=*@BpT+bq#UPS#&DyKO`^4d3!Cz z%;_AN8VU6;xx_DzG8cg7z}>+W)SgB%rb9UN5{rjz?@w9KJ z0X7rE&pz=>OS&S2!zk>dP(?1BcIQ!f*m;UigTiRScioz`FTVs3mon$l*TxKA+iGyE z9wx9VyIyAamCp|JaHZ?0aRh8!zDTRUFDS<A%CeD@tFALYgvY;Cuk*| zHrD*}H^`J)N<>FP1tCIIu6AcZ7wRklEr!n(nAlzf>QSqSii22zj-%1obX6gF4-3{Z z_?hJoaS0PkBLJEJr8lV~M7j2W%1hBUfqB(=HsyR5gW9BA?vLEr4DuOUFszE$I~}7x zcv;>AbS$@I-m}9_Ly-Y~NS%SVLUvYc1eo=D(5d4y!Yg^$RS|mUR3kOEwH|jPd z0gC9W2LPo4d~();8T~&#rGrs?m@uFk;ejotmUW8@0YDAdjpykd2fxOzRKlG!_PF0? zPFC~zetU*Qy!Ias1mLD`k@-tI9nQiVS#)ViYUDoL-@2CdVpGe+*0ziu_Bo4`;nOS2 z?Brn4=~UWCT)gPdM0B_bINgPho84(mdlbNBuGswo6X} za3~xof#UVn^Hmupb*;JWR#iA}4Qpj1h2cUexl_u0gN^FX^+ewbG33cL13W>31l5r| zJ=1MXzg|zS%od2Rv$0#;v9Q-PnCiLDzT0mu!B?#vS>XQjdk zc-Hm6M)gb&p+~jV5ei#G_U_kG_hJkfU&??Ur^CW)sT|}RKqts@O1>403=?T_z3$SS zsoG?5$bFY}C>^Zy)ArPbN%EgRHBMCqAU_52aho~p{SR#Heos9aKt^S7*F=AKOW~-h zE8Og2WM*~4=`C5se zxV-$+{U(vnYjS73?aKKs4j!Z6}_L&(~^A@ zJ6T@P;q4keP?@La*r;H;`fD9g7@_nyugK)4CFUzK{@E7eNwtd>aIN z`APu_)It8SF6zv91xCIWXfKI2G)&$Wc|BoL%c7A*yv5n-bk^Epfa_GaBvT8gIV;uG zna6;1<#~w4%u>C&?=J;}P_wBA{3qXA%xg*cpz|CJIwZ9g1JI*>f-}1*(YV=NG1cRU z#^3W@*Pw!=@$%_#b4N|)-Wm54VnqGwTaP!y)U(w4Jx8~4n3MKlx*>P|iVd+pMs7dy zaEk7gtzHh)?5m7T04?iW02U>mYrl64#YR}r?wsa%+bkSfZZ-C;M`eTOqg%?VG{y_U zLoSVZewKEH5XwJjGp-f||EntSk&4EdbJAI45$j&BuS z&q4VlmLj2+aGn99AjR?iGc0u_QcK@fL7Y?j3ji|OJOb(wqI8=2d9WKwu0{nAFL+#H z>a?SrnK?ocZSpQkHw3T}Zf-m5KPs&Y8HsCFAmpE2dxrGv$T7%EEW*W|o3*v?g~J8Y z5?lL<9vaMV4*d{KH(AbE0KopeooL1|E~JKYMz<3W8S=7lLxqWHcjRtE!`IFvb*W?q z4t#gi$cwdRmnG;Kl>P)$GHHbC2Qs1E{zdyoyY2wt)~nyF9zGIb{`j*I;M-)Avm!@^12`Z8!nj>HNL_+Am-63si6NPInGHgNQkg^fa5Qt6&*7Wt%HA z{_w31i61hIwYKbg*t613XmuF@&~)f1EDH)&X;01Vfv}C$4LglqPS`#)Z}|G|^vS4H zn$VZhSJ&E?{hd+`K)z6N?0TNq1nQQ0gAPC=BK*JO6_@!GtGDNA5Wuv-2-43Tbh$|0j5Av1cKrFDZxOx z{55bh_A|h&VJ@Pnk)1L=P%V4Rq^FM%jCQ@=#S0GYq7CZtugcGKN-D0+j)DB5A9Err zcM~o%L`M;VI_;hLoT)$OIvO3EE>5|?LR$2n5(KQdHBY76-q?QDuJ@>G8N{XUC=f@x zPS|1bI#?6~?o(Kix0Op4o9Shq#YKrB+eib!2NpKz&zMZ4>j55-%>Vi&>3AOnA@kza z0+k3YS1X#8DgDJ=PYG8_b#6X|do^7>rQh-T;FTud%so9p^)Nni&pXT$aoLIe5N$Zj z$|n8bNDevvYehoD)%L?fE(rs|P2SpPR~@XATiJnN_wyyOLg4IMx5SN}N8^E8?5v>*Q8SA8|gVS_T_b&;}Atg{|K_u|HME;it`3 zzNRr?vLpE>RifG+oo&as6oxB5CSzhRQKXzyvW;p&!`R$B#S5naVf}iKMHkzFKancE zzIt>1->==u<0E&qN+tGKt{osS?N6VQChZolA+v_&hUBh0D;wHuUl#ER?tuXf($(gB z1TZTSZ=9h=$xHUcc&EP_fuRv%L+pd2uUD)bQhJ5>j5@$v%`Ct~=1!S_KT1cRs1)1ikuOrXzQ{G1d#gHaVX(GZ|1{j&L{l-ijiK)G zI<)#1^=?QqKzGg{VYwF!yo1l%h1C$RVB}>&5`!rWTp zk5>b*Srkw=1Be@r;@`yrkIpxIo#a_mYswM)Zh2jX?Wf9>wUT(T7{on2aZe&N63V8d zfa@IDy3L#gX@8U6P7yEqjJlbia6{J;=PcmJ;R7@MdjB!v=^@K;;oD4E6}GF)8BWPW zfQ*1tA~I#3#E~kPdb_j&w}^O)`jlT1OXmdTKes8LA;d}6{cC?+BS0Cc&^EgpRCL?X zrR=!|M1V^P~HsKIKM=i|I@e~V|Ok@%+YHHkJg2M85phN)1mP$r0wrHC+gvSJJI z`LT2jDN(M>mTznM>4FIyF3N@!e|EOBf#UM|WaUI~j4n)7GMPYKag)w~g=$1rxYdsP zR6a_v&%I@Ab=oB@`}xzg@as2Jn02g!4M2Py`s5};1ttMQ^M?~06_vQ3e5q?$EU15$ zSQh^l)Gz-&Nb^yFva276xu9}80n`KBd2u?r1O@TA&Q`wht>^C{ZcchyvZU_)(+s+#k8HRfgT zWgu~a@ZY(EXH4xk)*f~AWH|H`8nH$VioERLDFV$D+t|#m2dW==ykiCWiiWQ`dDvx* 
z1X({9eF1@Ffh@BfIYDslREpn$$#j8P)*+Wdlzwx3vege9d%PU2Qp9&AMQ`b_^7GuS zvGVW2(39)vA(=QI*TwyG`)=kmin>HFXIrt}WH&Ps!P|161HdHrQB(nRJHhJ&K9S+*4VP+p!UMC9T|sZFR-7XA>7JdT%;d+vP0Dz|n zH_Lx1n=C^!1peaZ(k_O{yq>FRUwZrb3)Fum5oLym64#n~+RB8w&3_bIgy}d8%t7@_ zq)lJSQ0Y_&mkf9icC*Y)H&sHMfyDI#-{B-cZl+6~yc<4@rTLc1&i^oHD!NuOU-hoz z$~4h1zC22jYAxj{_bSxinR-}?x=4gtu*}XZX&}={yVOgomI869p9W@lc4bqAoSWZ%21x)=KZ0NK3tdA)i zv2mF7qt#f20PEe&-@0z*>UOBrDqF{0fzqJ?(1;ZKl%x5<4Ar!F-oQWL12UW%=-vF+ zOi#AVfvMPBkbHMk;~}HP@^hwdfZT9EGjG`U2~$~Sq?)NrQrbY(VsWZ@X8e)xZz#_5 z=C}gdeMiVE@ZghJOl%VWS-n4!{7vi^jNmZ>V0}Gtd?m88Q(Od??j_e%+!ldzPUE-C z1MRN{OPww+_6*!ie(z`Viphj;Pa$rrg8Z2X(@u6arf^2Q|D_^tg_|z=psMw_o4#(_ zS%rJC{IP9$>G8~Wtbpxr=q&aCF;@2tO3a@^(6kYsvNWv|H`^Hfd#pWn{X!V#)DDAs z6=XV(;@Y9Y>y4ZDlCP@5rXRI2ruN6=so!g36uPPOp5go{9I~h%KH~R{vUFYoa#a)}}f{J(42ma6*mj1dtt)Y>Wo;#H)bJ>BlU!69-RYVZ-0M87WMKh{BmNREUbFlUs6B2oZ?MXR`Iz) zb2nw$dYhE=AnNu6Fa~>EI=;2#mM}p-*DG#W@ZeAF?rTTzK(67CWlDYza7hDT^L=w5ymXi;bbok5RFa+k?s01RjLm4=w;)u^ zLrH&*UX0z|iN?VQ)jWg<7;L4$tQo`n}>_T zD?hnOnJs|*6<78(rNUHC?PI!~6Z4p<*9iI4+o*sS@QOTrcJ5oPj}Lh@E@~EmkxP$v zL89q$?UhhCKKjTp{wqZ>jR@vi7@(B%+oj|lgd|3%)&F@F3y29iMB*WO5=(n2Zlqfl z8{m58wk^3X36tVp#3#oa>EMOjAB3QrQ9VnPE~8dPCE9uZX`FM^3x&xqMZ?VTPjC4= z;MZHf3@`oo{d%t*=(zIVQK$TqzghkN%_6#ISLf#6Am}uf%KA5m7WNIp`GhVtBRyS| z6=EFw27M{V6?endEk%9aA8KiH!UU;3g}-Mjeiwj-K(p{vKy8i|mMR0lQ}PE9G$u41 zswcc}Wt4#d)Qwjf=nT-mLExf(0N1s;I|sN}=l=~#9c60a+q3S;PjphDItP^4ulKvW zP&pDNr~HH%jmj(69QOv#~7D`F5Y2f``^VskbL<`3~sI$=_>V+i8ow6fvLTi z{qf2<&nO!ar+}18Cuii9BpTxaVmFx|ys3OOP)F}YuSZVkjBvW@6(&rW&=IJQn0oVsxshS1teL z-+=&Op3{o-C}=NgG-RdpjPPN-Qp%H_+r#K=drO|2bEq_ZTXbLJ9D>F?5oQ&we#T;_ z_SI|u@;m=0zsVPkb$oQ@(JA%oiCh<2veuLMb-X|?^3P9=&M$yVW4JPvvaXBeRI@ssHsqKmrtoDifJ z?-idzFeFJF2yl-{HGXAI& z$!49^c2-fFr~NERzx=jE%Dh;SBR4op_-l-E$oY_u$>~O%X1idsjS2tGRf|r%QTPw{np!op%k-keKHn-KnH7mB6*29> zGw&EY`F3p&^_5;sG1c~2Qj#3IpbXMa-(b2#A}@!XafpJ+cUwJ3A{bll{IV60+s~}Z zRWhZ@G|47*>SViEr|4ZW3`Jjkq==%Ap7kmduYbi~)uhkgWotomp2$0AM3p<(Bym55 zNZD2%Twa}#qI*Z3iG8G8@b z^xYVA-o~A)_J$TbS`hUirRho|K3P5D!yF=FvSzV8xj86)P1AzIGb1zX7k;|IYkM|# zp{iS_8fATxd}hk{i}VXBK;Ij$7NSXA+!?guA{-67fLB2nmcf|E-gTTA?ol2PF zX?5E(Q`DY2Yy=8#tHu|R%2?ri+2F-{XEmO(zB~G4IQgm%^1_n8ZcWG7lj5UIJ%+`a|8N zkilxz=bwFTti(1WY$2%OpSmd}-P!z*O<`6k4!Zv7$xhcaQ=YQ$6*HOtA+K>+LpbN* z$NQxQuf<>`d^$5dA#wR=epn99>UE{P8E%iLXWAo{^=f?hrrQVCE;V4k)`CBSi+9&d z>xd~y34|Um9Y5YHDuSAPWdWP7jW|@c@D?7N98O)*M|S#;*W3Z!j@~eo2#s;i0~0$V z!Fq;zl4hdXuBnF`_Q4wa3sUaAFA}u+(Chy49eur<8B5MT6`_2^U)lAUX+7p^<;?@m zooRn_XM{rMrDvUW-jLDMrVvo>zU|QS#Aw_==kt+l?40na5rC)IT&$-)fmH z^kd8dmZ%5M%s1)Cl;DL!hvd*ij1M46%rLEc zU9byDBVYBP;NDMl7DK5O@Sj|}myjbUy{Dt7-{vE^^aJI765+`o>PPSLN?8@3G?h&P zQA4wH6!)|*Dqq^Hcd)j^uhnG3l)0QqI3T59{m=M6!l<5L+g{n@5Suc>3?oR{}P_bO#8}+%dGfrAWA=dw0hZt5qwrj3j4?S z?Cm7L$ma+7tzORCi1+*U%GO+2X+W8!|3~~e~O&n zQtyv)Sjh4F?`zfDk^EurZpoM32sZ6edKIhbuf+3Tg|y#2&Kt60Q0{~#D7ROKOj%3% zrajthERQ1obLkzFHzw7{**IASgDHGP$y0zNM_rb#YC~U1_LXrHZay6`SS@@}9mII6 zcmt);sn=!}Mf!ijU4M=s7rD=q@9<~fn#}ja5I)@AdU|DEIvGWv!w?%WIIz(w$3*goyj?Z+YX~D*=uR~9#MN3E%TdvBKX!I9)2!v z_kT7Uz_pBWHGNjb^(tjJ+)=;NyD&W4I5%jY@3(JM0EbO^~!v zXjf(9s5(WbC~rlT{mQESWG}sd)Zy$m%}?{I^?9dMvPwJ;-~@{8Dvm3`KA`;%J>^LQ z>5z3@9c5(X{3T{@8O=|2E7wjzDP3PghsPPknSDtkN)qpAQSDgT1vDE7)e&y#vj=Bw z%ZJ@-dz{*IDFJOp-K{vyydgs-*Rl=_9>I|8MBWL*ej@XpXbSdv9@})`{)fkh_k@&l zpXuCuGD5${&j0P_owjQ#1C89vP&|sKW14BGPO-v+tqz|zrnp=AnU6IawifWMP*v*u zcswsq|82EqQ1}aSJwg*}NBfXBgmzqyHJy?aZrs9yy^?QzWvXjq7V8S}!$n6{ ziwdtHtjuCQaTllMHq_>#xpSo296`e{fm0(ETo!MtGwDs~4+Ie?= zw{*U0;(KLM4^*4p$RJBOFxG)SwH&sY!Y4Chg*8mUt?;TOl7n#YF!dB`2^b?b?WxR5 z3)dKX6&jFr0nqH#e74p}cMMsv$;BJU0(1)G!MG5o(q6gF 
zaKLu(mET+U{i+Pm_qGQ+8PGDDT_HvBFJ}8KDyCcGpsncc_K=bfHh`iykO)M`% zs&Q0xlCrYDYY-h?tz*(0+|E76t@S>yC8YJ~Guuy9H($N>QA)NslE%#izxA8kgQmAl zFZ(K(3utt&dHJ{QoRyL9mIsRG0AUjvbLS2^A6qAiTpdowHL(eQG4FTX%Q`vICCBMc z9G391_oU^g=JzbmsM-5e_+IbcqO%!b2}qo8?rt&a+_$RhR5n|AFVH>e-q&=4>BkQt zus&R8zmd4zepg@HWe~{tdYtB&bVo;jryW#uoV|@%6#wDK9xReNs#g4P=_>o$$3I59 zi9)A6bdjsnn}r5(e3w7Pd>zP72so>b^(DP*8V-S^9fFyv3ESSw_3zbu!zIrNaM})^ zdn3JqxmGT%t#a2-U$;a(%qP$)K7tllLe#=$O({VzOLw}nAZkfIQN-GG@^Fv4U(w_e z1GG9wSaUg17pHj%Uu8Av27{Mx$of_e+BpZT-Z6`Kbi-zam(Ivq)e6%|+!YJHNr-f= ze=o^B75Y^%B0*@bhbIY9B&bWIa%H%JRfun)cC;F(o6nLh~?|O%Kl7eZ8|^ zRNu|KvkqL`xaz0Y<#Fvm%H%a&2WsW^6nm}-T6>c)=DKLk_3t1 zY|MLa8x%bA8XRBXly!CDFosqiFSrnG?v%R!rWP8v$KXMnKwaj9SCXFU;ZdH022(og5|UAJtlv^xd*~TZ=Xf zmT2RT8;77VRh>xtcjZ72AR-Tn*u}bSL2vt+Ssi(POdpuBD&p+9;&wynlp&48zhI3h z+?Zz>H<_Q7&#|b3HJ5PCt=xV%Y)ylJzn_l1C61C^j*G>!$uAVJjm1-TwTLDy9mhaE zUP7upa83A5OzAJ*R)*J|7`7g5WJK#=}bH_F+|t8j;bmOs|I zOguJCzD+)g{NrrByP8&DoQFAxe%Hqo$-ar7GAqTUFX^V9{q&}zn`sg1#-mo0+4n{d z4)rJMJ?6@7|KUhBG&rO%RqXIw;Rcd$L|+qlhnPzayjr6I{xZc})AJJUu zs{Xx6mSKVXyF2g@Veb};6mao18GTVxb9ze^7N7m4Uc#a;w8`CsJ4J4$cL6dmvEm`) zqhGo{6oukXAU{1+y3~Gz8^=rGh37}q%WFr0HSMwNDkwP^cy}7Y5L=%8wJI-&1x9y> z7)9V+SXA{?#h^&lzbsQ>!ZjlHN^AXNPr}w@11f4dxx@~CQ`(55s-;HP-=>Y_yaUzF zRfCCNj-NGS+%xS#kRYj|JF%$K?~CZ1B7C_*1)u!N;&Z0K&u@~14{ZD3E;w(@edjEO zQ}(p!690U-q(j_Eri_asa_uDC{1|k>ub^}O2Ch7{IGwf8@wu@-Y{GC4)cX^{!_x05 zvceR#Kd*Ye%7pM1yY_%n;3?CNof4wmE|+7!yGESqrs*$FO`b4}4nPgmS^p${;Tyv_KI;a#VY_3e4Ga@X<2X$=mrxIlM{<{!Z|Ybrv> z@nwSO;(Ebbdfzu^ZEn2s6z~dZ35Xjlbn-ocfM}!w4XWO`xRhxjT^AE?#%j*gmEWIp zo*tBpOp;#=R3Fu!cR~Cjz77p5d~$5e)r|e+RKP5|ODR!~>|^4Y;oPNg>k@-q?6 z_Bz!T4SJtma<-qRzZHFePBcL|fH?@YtJE0{WLHw=O0m?Zn-H_hI-L@7y(^l*6Nhq- zy&Y{ubrBCvXP`xR|70UW6nl|so~yPqWjvi6XG=*(aKhMZSfZH}{AQ4zE2y#{RVZNk ztqQ9^XqL}y9G^kl(wdl@Tb<%!nXN49?OR|uA5=prSOVbc*&ZNaPK3&KlyNiS_zQ>1 zbkH;?(=X-SF=vZ0dHw`bgnnun_jl;`^q+rsxYnBK^Uo~A(tm*byb&=U=NR873T0_i zNwqbFVgT9-@&i%vK!JBzSRGRrDE!x)6Ure?H{n)F@1y;= zRFI--N3ir!*S|RoX^v6bZfpRf0T6#eJdjoTZ@0I#4%F41a(BTGCE&~7;aJUM*+>9I zzXULy)}J)pe|2}iAD4o&MbRg~)Z$mb#`vWV0OW^a4nRbo=r<(vH&q)K7s+9k#pNM=1Fy6pw2Q0Qw+NRpe-nYj!*cLT ztX}RoY24oW!4yb^aMF=Bhwi{!HSjN?|8ld|*CD1+Z10Z$*Il}3wn0Bk`;M$}`l3ul zpkrZB8q?2UOWYB(b}Lc|^Tn_EhOMT+hzOJz(BQ zPbFC!idZ*8D~S|t>bW-N=U!RR`h>`|qu=R|O|OE=!6bjV?NQ~n%wZ6WU6v5?&6h`I z_5PWkaFQ9Pkg0(08j`OqLVo3|ZeHXlmu&+xN%4YuWV7y9JH{Ms>aJdJx1nrZc+EgN z?lGuulic__?Lj2L&Uf88(LTQ8xrfv$26Z8K3*NdD7e;=PENTz>=qZy#Gp&iiNaN@} z-g6U0%H>5_x;Aj-eU*F-f>qnUna{YE;f)w8CX?{x5%B!rUnY=~yHP><)JxAjMnaM_ zKuF>n&t-~iqY)+Il}D>R^%Y84N%L#Z5-1unO|SFZVrKP^ z1y_-J+YDR#1B7_YhK_fUhGK2XeB4)DE;E;3G2N9YvP z6Of$hY9hgwu`P3}c*ps==h@f`sy*1OnpP+CSGl=!HDM(#f_eQk=JKp1G#&gmNRcG3 zZtxYDnipboUhZ9vGt+k8mvZbXIr+kavsk+LtQz7jG_*)gH(A!>v)gjZR(92zd?uIi z^EIhlIXzKtx7oTan-%j@?qe9vwao3C-jc59Qx|YeeCyMTpWnaj&Dia}B)>7kh93-$ zqYlOUqv%VY;G6XuNW;V#}41U_yCclQzSw+ecdLZLRr+XJSnbvM6`SlF8mmkwa?-+Az z`t_Bc9w+>OGZL~p%*XZ3pE0VPiS}oVZT22AMTe1%;t?vh{bD|aix)mFiG5EeTu+uT zU;7>TIqO!)=za`^$AZEbe6`K6^WJq~9dh-WeR+(Yc;Atn;zzxf&f2xo#{<1Rb8;I` z&>rP}!6#4f@Q5<0-i{a<=Z)n)`j5}!>zK1ue|uRf!-4|DSRBg$*_`&US?sp-jkJI% zaLY9d4Z-@z+v~f<@N1DL;J8}$+Ta^w-;s^%E}|737bhk68#dVy)-}s%R^cT5B(J;u zR#yIj98YynYQ0IBEK)>OMu?b)WFMN+oJmr11Cd+oy_qWHcN zS6ll=T3YHK*R^oroG`6sMzYURi&~EfJ7B?Xt3$J^IK7U48TZ`kU8~S4PsR?fbN5oD zcAnNt!@ZlAzxoK^&Id{2-E`CS36MpRz`|+y9oWZZlg-%Yf}c-ABnXBB-&VrQcfJAF zk6<6cI|f-Tv+00;xX@K zs~1Dr3mtdb738xduTJDXHFP`~)g=-6l5oO`MaWxb9d%j{k*$<>-ILh7K%pHP z^sx#kVc^MRZew!bk4G2euy2g->Aq#$rj5qOI9cM+>_U}D%SjsiD5J9zN+jQ06ed$U1e3xvv z0_em`YDf_$*LkJn-C?6s%i|Dfc4nr#38rzXb-6crH3LY2i{%x?ekp&9VqNkHzp4+seDipV!SQP 
zbh_RTSF_GH$m-`n8&2zOHd*cZ@#wWnZLU1c9dtB|s@ijCQRY0XZSq-iyRsg9cXFpD zdcaYUnHmSWr&G9<;d6+HLhQNvC~n1Q9FYg2iula9XRd?J%RizSvO?us$$CN3-!>C0$n?6h6G~rp z;nlZH8H82c%Wh-tXOR0Dl?jOSZI|TDR}twBrg@2RrdWGgeAgTEsq}}+e`)3Sbkk!6(Mq`tC4RsXbHU(_iah3<|XBWsCQa&=Dj?`ybtK?2P|P!>z*sM zN1!&Ld?)1b+uj*>%;UA`mN=|>PXr9SIHO27DtO@E_N0fUv)^Byp*c2bA5e;$H4FVr zs93Yn+Nke!-WneD;$Mt)Z5n9sS(0dFk#QxnAemg8W{0e-Vf9%0X()e#U}~E7?*s$u znw=HSZ?PB3`M_-7_fEveiTVPICq}Y4(U}cAx!k^Q5>KmjSG~DTB9t70FF5esa0y=4 z7R$Q_=I2;xJ%}3={`6CB-EU)Eq?izhKj+-|CguCQ-{-cHpGxtxGpU);HfPcdC6Z~u z5`98N(krs)67g+f{&a$@`qHImHoq*A$dtsqV8Kvj?|>D zN$jme{f+iZm=#3PgT==xLgzn5=h|~6an+}oNi3U2QW=f$HsulwGbC>M2Q&M1Tx z>Ji$GgheL6cE4+MGo7t!$HbAXdd>QSEBBg08facL25k~uk@=L*x9?(=U*7mVS0L$f zG;_}e90oUES$u6oTII{O7w>|%Dl=VuutG_XF3s-Z*r0M+13PF70YW>>d<_MDD&D9G z>J;yQDzib=cYor~(4;=BtMd^2RVDD!p+Yl?R_^6`-8bR))d%)X%3P?U zRN>g&0R!CO*QkPcu?X5rD~0hiLjlW&U7p#M2Zi{R(Q60Xe>^;9xed<@Vie~d%r5_T6LbOcdv#38;mjCQ8f_io}81Hw5Ez|3+l2g4jqo_IzYWyqNX$fPc+F&aE36@pncU`EP*>Y5y(z`}DUzpHNoQ#GC$T5bBiY zGXjPW-{EsAM{^T@UsQX@=f9{2>JsXxJ+Gu%%@Ia7#V<#3{`&VD!Ti0o%ipq56VM13 zZ>Uo-#QJ~HK`Zfp62C?p``51LM^snimMFr5b~=o*C=+i;X@0J2WY5jNJ^$tUSNOX? zxHKzqvH(c-_gHqN<{13(s}|P=dN`2}z*<}VUhxdn1q*XMPmw5TR=owUA zp-`*f{Vk52qf)>+rz55D$+`Z@jo#CHFf2xT-tS+2SYR}HAP%3=X-D7%*WbLr?7rCa zLq}B)UoC%kH^|wYf&sl!1%h_yTD(D1!2o>wilOQPGJ#MueAzmkIt8%tHmktt;^A{Q zQRsWN*?LD^16~uWx?ZVqs(97PWSn#E^P3k>AK1%ZXxtIez-x~!-d(!#HMD1nUwOFY z_5AM`ux1)t&}oK@R*xawVVNF3RrN;@(}DE*i;exEdCjPpvSEW%7x*t?0ro=t691y8 zpI&i#YhhmQ_4N;}K0=>7mHo<%7HZSWW#28ykls}qs6gkwu*E$cvH-5dV3TuBlA)Bn zRpFcv;}7;Cr?mpKSxe2Ar-Ja=X2`x0*q+z6`_!q(ERwHG(3QKMA8hmrw(fWNi)_6 zPgGeW2o`K5l2yi*D<_Zx%wc=X#EXQ^O8ZqIr7Y(@s+8bvCh%6os8{5~_vR-!?d3D( z;{Ho=*7aM~d$?;9iA7#>-!_t(^=lM$rsiUnZG^W}vbZVZ}(bp6=^Jq6}LvS3D>jap7V5GHZD@C`2u84FZ#NvcunwqUWJcFQn-apuGd# zu)P0aHKkxZU_v-g$8ZPUzP7Ho!FOBZ4r#=MSGDG$<`$H}X=K)7b6%#(vF^+F#Gfu- z{nN#5tTQY1{Ya4Y)l!*F$++VxVkCJLe-8u6ubUBPAY!}X}&YRRTWw(Eucua@oiRKgO-9iy; z5QB>tF2@u1&MNM3i3dE|d3XT5dG2f*9lq>DCPHc5T|6yB zw~spZp!vo3Y4$g_9=pqN@R-P?U=3K-r=@$s4`+r&44S+ID*QcaS8u8+rs^`XdMlS$ zEG4xz5I&UmGJ3chQ7Lb;^F`hkcP|Ms7;BFu8l>)QYCu8@jKp*Ft65X-c{`m7&dw+V z)`g*s=~iNQa`<`Y?{{PqE|uyWQe-ZEC+LH{RN=CM2AHMc{2v z4|L7+W#N9I%I2RmX%xoh3;dbmQKt`_aRx`{fd1QXhVNH{x_mLL`Dxgm#_BqDkdx;~ z==Q@`sW^pG{N;nG0snuRCd5XHAR0>m>b*Wxm>}M zdekpc_JUP;k;lT+`Pf2HRx8~vwPEpmyc<~C7IHB znUt0gFi}9|4b6e0Td5k~aIdZ1I$e0ZqMh99C%&>jQ zV?rJ=%qCh(I6U~h6|A7@?}jzSA{$OEvu!`_{pQDd;okv}G~9d2=S;C4r&XkkRr=h(Y_cB+G-ogvd&pjbz8*C2z1n;kj){_hPhS%;#4B*VW+{=06WlpRVK$_6Z*zGayomRt4o9K2@HT z_^~nDZBF2Yu{!B)ww*e0-W0!Jl-%+sQn_m}N^Jo;Q<>a{p;Qfm$%%(|R{Bcp^GKyr z*eo_Z7-$VN!wl8Sq10#b8C7R2k@;@HbG8AJ5dU^``S^h{mrvWJ2bS#1%E{3Kc8c&1 z<|KYZXmZ>Z$r8S2 z)@^(-4li<>$&yhen2s#rijpzWtLDOaVt8jxFmjt*j1b{pedmP7EU~GT+((1w1+Pi^&I!oX=BF2>c z9Z-&ztIBAO8jW4e$@QwQs3;y?;We;3T-!bFcdq!uOVvkY{eLoxS&F~~RH$E=0GLT@ zDigjelMkP!z+WmJBKbY`uQ|F{*OLf*va~8Xf+~&Kjat#!)tSkDc1rnP+ct*3JxUbD zIYWAl-KlFp2?S1lq%#EG?}alX*8w=!Jq;nqDB^_XJ8jQt;Rh~*Kt=RdK=%J{{*TdV zD`FD7#PS1hzlXV7vP5P~ca^GC^|r$E1JfWBxz!^B>vY zR^xx?8h*d|Z+5x=?5gVjRWh274Ap`*G5sHuLT;Upqe}E`knzyaM=x7TsNi&P!xHrP77zvok5w)pdtdT-J*}rYzg^d51 zx1YPwMlC5-umEe&dMeB2WwWOqZ1jY#<@r{;QK-sS_2(`kW2yj~G6a6UeoPCHR_|W_ z4Vue)2h2X@k$5MMg)s5eT6lAo;<#a zPb(jF;jTg?F(Dy5QpgA)J)@@R>E7J0n64kgpBPvR-clxBsH=`X!t;+9yx<<;5vR5} zSPV0JtV))W4?=ue?5n9a|0sW?rqa31WSycQ@s9CwTxlo=oadudo4)Ny&X-~*tk&hK z62r$g!i6@B&UU_c(r@4vY7??K zs(iSucAmLLRvv*4a<`vN*0ReH_V{nDXTke#GJO?o)DIB7!tbXFQKu{C;2}c^sl^R3 zAW~q1@ElKzPkW;$OFh}6*vD3*C7Y1=5fub;u>FahZb>YF`3mLOZC&KWBA=s>Fi;gG zjA@jUJWYasKt^=tIZrFZlIO*_agbNUvG5Bp9;XGl;P4Jelq*b(Hpp!!C@$4Qv}8FE 
zbY&^PV0#0aw+O$Dw)Ez7i(g{P|3mEP%4>~eR!$jb?SQaClqulx(wSimqt1s@Xjq+Xrm_Zmc7^SYILCi!z`cm`6;nw`xV*$=8fjmdKt#q;tquJ~oefeIOu4w?tizqdxS4cOQ|xsY zO@40`pyWaWL~b~f)w* z`Xvc_tb+In;)gSjq4!`nRahE^&sWWHvluQph);wp6D?7kZ-EKD##zrRMFh*=%+>&R z)J_r072X3Up?WRdf&=Z*;eN=WtUK2FgAL5g_M=YpmA&53DDClAp`K`(9#tXHNNXD~ zZ}h;57jfZZaHl1%U=}Bnen0x+ac}0cP7}{@O@)V?6Ji=B1&*=s0qW&O?;)dzvc4%73&xTTfSNFR0qI_WDOScaOkU0b zKj!K26&!l*11oSKl3lq8i7Zklo}b&8;~OzANm+2`Hwg=7q6*(+P%mM+GsvHtXUg^3 z>(0BgS(3nJo%GslByl75thm1rC*Sbt7cI{frHtw1ubUj!ZK+BYY(C<|TRH?L?7ZR# zjsd=K@k!vs4?n6(sde>X78_Hg9Lcgw15Yp(A*}@}K?I*UICa5N9!FkxQn-DVe^= z&t%}fJLb-j@0Z}0)r+}^ifp5(C^R`jQz10P*0J)LHY9yU&9Y7WkQ6s(F7Yhy{I&6s zHU;afS{Hc2URqfrf;m30f>~6H$d<1y)x`P-EN;VBV&FGOBh(ckg-u(c#Q>q^d(ZEI z?}q&(o&J1HH}~U$^OhmI!ytV$G8sKl?p=d zUpJ@UHPd{2uJe;H1V4}kHhyD`zl3f&i$Izh3VFQd2%ytDYkeE?W*_AEBod~-E?MyA zhbharYv&9RfDJSn_6u0@Xn?}5f0jJ}<)1qfOcT0=RG;gZA8EMRy2LxB(|T5avy$f? zlsD8u=$4M>Z;&e!+P28uT-ze#3tCjJ-c z2cc%iKnQ@jH!FXGO2XlIw&2>8pwMYM*9!JVOP*9`X>t{f;M^jQUDe^5M}+r>vo~H6 zy2g+2I*5tldr8^wU)45geo{}zZH~z0omGh|ox^;J7bDnF%cgWcw;oo14NYTf5iFLt?ykG#Fxt0g_Jcfw* zL9jhV&p_eaJ6U~})o)pZ!M`CHtH$EG81dksL7*UyoUqimU=lKWOKoxtrzAX^V3~S4ln$JgUi;a1qN3TKPe5;Kb*4 z8sv0?+?c=yONdslGx%ocwc_+K+>72$@wnZS++Kk_4f;OQsAs?dLMnI^vy^_}uO1l7 z)no5c%Krkm(kXB7ISq`_-i8(b-9)O&z3K-c;j-BgV(22|K~9;d>P5MDht!+ZMyJw! zNC1DmQSR0$t-rNe$GLp4^}(e}M`y&=GU^(};9KU=h&zUrr}}XFV;Do)NXr0H29NEw zT(g39o$TG5dLy#CTl1p&oM|-$K)ZHc2IpF{R~UF7Hu>r%vv5OhJucM?3X%xK0E!$| zo!Z5VB(~*PgQen&iL>~Qd6I3I#gd!!Q;xk!ElR5QUkq@wUlyWk>|Qo+< zlWPu8hMkCMGc@Oc!hmNQG-!Zl%uie^ z7OlQ$ghU7*Y%+778;N=uKAC^(^F`>$XgMN=WsT&kq zfkkg7U6nah;$Ga~`1(mTFm0JL^E8aUGIcNWsxL0&~a44&dFWiYOPodhwW~4;t zSZD%N^Cqj61FXT=TYEEUJd9YC^mFEt6MwEkL3c?D~GlNWNM zIC;z!wy~J*6LHyLN-~8>vpA7gorag(q;>l2S(tzIH-vhRMl3`$F2LTtw_6*}IRSc@ zbnw1YYn{c^fb{mILF&mtJ5*v(_;+?al}GPwh$4eQ`3l#XdRchT*Li82u8GL{?V5kN z@Kc)bHbCm*<$cFRt80P?=6)P(Gr6qrFxVlcuIaop=#h~uoI*b{?TvftYfYL!^XJw~ zZ}+W+m-tyUGdG6~pwFtTi_~xXx{YYGanmPQ-fvjlj?!>}Mv^q~F#^q}kB*z74`(Ey z$}8n11z#VgIsFu0cG=Sco>x5Laii+BN@LWuZl@qIlW10p(a-HJ5aOd=JZH10_IW7Z zkxL6ZVJ7oaDtS?Fnb$-yxA1`DJ^B$CI(=2s{TfudDq$;kBg+NX}NGW zJQ1R>X@hkWL%7Q`(@ETur3}nKv}-yRsa;b4rKx3AM@t6n!uy3~A&+MK{eB&e1#>xH z{Kb>8v}{4VRaBUOw^L>#4JQ^L4e+Yn<$hgqE52g`5e+XIl(16G$ZU7JXGW63y$9O4 z^344*aEAx1a*I3P2z+#rA|iSwrhdCoQ6j`;28?Cs_qwqVB2WRlih1ThMp6Fa!62Ml z92(JXwc0bt;OAtfW{J~4u3N{wyK!4`EDO9iFjx180PIvebJFUzR*Ul?Q>??R;r))L zE*-%I3D?M6K81d@YTjBG2t5$>!*I$sd~^35AL)QmERk+cFr1rKHu1bnQ5*3`x_2il z15WV47ccX2;9m~ns~R!3p$e&-7IB2_R>q$8%7CwEi9uq3N75+bc4tp6MCF&MV2b3_ z)U=x6?NtU!qY#>ohmScWwhY0ord80}^|G_H)+)DmiIsRZ2d6k&&;8c?Mn;uC?KXD0 zM0bapo9eR9bmkC?Y?))P7!U^*{#t=!IS1DC>VMW$QjjKUSkABgz^DZ=%JgkZDq1;o zZmc{cRKy+7>5hfETU8x%o}*kBq!>NBaH0P6ZQ(_Lu+}g}2f=)jj{%$@Ro@XcwdMr-fe9CNOd|L6S#S@oQ~d&L7xu_a?^32XNC5~az{;I)twf2u)rnCP&?tQpHNVYL4N z!MFZ&QD$&bmIM3F5bcbk{0*A#ygt2U{22|ty?$q}y@5`D)Mk9CWzxF0*Pk@?RauFY z59Gm`$dZRb7CxkS^25GB(#ef)Yg$11*eicxz}93dc)ia`-pu%}%2wJzWe%;fe;#kq zo2;v8Wn0;$tJ-7u;Pm>c#rSqF;w9(MQp@#m*OO}DsjjC&pNC`p#ke+-FHyV{AI#I) ztv1w>Jo=@tSr2`6M)$CHYtc%@{~o`vK{f!voS5D{gB*Z?fgZxj7Jt(N^$&Iy|86Y& zKQSe~;y&*KL9ex^oiGBM?(joY`af!=K^`3c@j9^3Kf0YE{;``0eerL$#6*8Mo*DiT zq2VL+H_+*%HSk{n8o8kK;0J(r@!u@XiW~o6W%x**FXMz{tv;{Qoul-WXmoZ%wH1q3FTDmjqekgvaW@O`Z9#$}?FH(CeD{%Q}VKvRnX1p?Ty5{ot4n?&J zo8Th}Ix)c|Po3WGp~D?B!jo5c#RKXM`0`%?+eR!&g~a$QNC~*>74S~~-X=VNwhL_K z9Ytcfa7ebruOR&QpI0GErtqUX7cX=GM*VOa3OwG(2R^Wao&qa%(rPEknXGKB+#N3u zizgTGq$EgyGwl0{$QXZvay0k=EX!JFdpduvZEj@<@47#V-+-jNMU9Qk??NCy%8ua) z7$|_qmw#eE6YBZlN8`WHC!*c=6Y>myUT!J5BGYV1azu078RkuI48_LlVtS0_nZT1M zx%4U5WIbuUL3Kif8%1lx4R{OmW>z}>yT1G0$3~mMhlqLGy4V3gAh9C8U0l#!Ya`h| 
zW;OVNV*O)|Y%6cA$m^q=JKP2vmcYMi`M{!RRj)Ml!;iJ;&0wH}VeDke1j1{4y#BEkZu zkTP1BJP!jN%|xbtwHL*b4nd~}B~kFR`pOGwW`}G7n}HpC-O3QvnsVQKztS&j#=k0o_s7lROreLNR9R}*Soh_M z3wjTx4n&skp%0vC_Le8EE$(ANX%Gd+l{uXlS$pYsq7J#cqoqeLCNKy}H!HvFI0LoYR1zJxp@pF!Fu$*LTMBsY>L?fn`@Y^{gTrr3vFdzHm#n$V_Wv& zM2vy0CzZ2n>)@xfCu#y*E>cM4CN#B!5^8q&bW@$^SC#!}GAXt}FYhp;UgBR{``0J6xIeyHp(Yw?s}rmcSy|#S)S<5Fmg+fB`c}|)m_%`OlsVoPo;(gb{PSVc7xxNuvLyG4Lj#!-tH(>WY zAM}@g>X2Ae^GSdsQTMvJH6~u%HYVBYCPjs`@eO=9y&tnd?EIjd`6L3z^{ZLML*ZfR zsHf5NidoLE2^EY1LmfcS`#_t65eIhgA^-sQ2%K+l9;lq~kut$Z0j7#KA?xEep!;Jo z2oTeOX7>sm-YNQw!*P|eoD2R5zVUpgJIA9s;v5U`NdjX5$=!ka4=IX(^VqmZv<&B~ zv8-DL1_O5#S-)nGB$Fi$VNJb9fs}54le+$~H?w=Sa*2oMdjuMIx5|De{tF-SNYFSO zs#RN_cciJl2Mqi9?E`?lv~;%l-`uQ5k>$&lbbdwXP;0HR3=qXRnhyx!ohQgO3Ie8% zj{v9u9h!MwJ;5PrW9bz?*z~@us(qnbfkJ)QTj$Z@ zyH3)+J&Q_fbB(<+E&Z!O(+Y{|%Y%Jxs$gY1Tt>gUd^XC!k*ygRU*q=i-+6{`3-j0e z8)a8386kTH0R7M$Azh47#~Cmv;uQe$--t``%x&PSXc-h1);WPM1vQUv&oPgE-#5aq<9Zdo}0>@{P5_`Io`(%03c(Iy6w2M582%_jCgl_Z8*r99KWQ z62xUfH*^%QCIQ^77$D#<(w|c(DxAd*kU&JKE-%fmyT`>A8NJ7BQW|hS94ks3f3c9v zCRxMd`*2bY#fG4XRCfa;zV~9;uQt;!`pe}lBKJ)bpl(wtT>;)=DfuJ@i|h+Boa?1J(p1!aBIzM!_eyhU`4 zXEv(F&%=^Frj9?06KxWAx*M~w?@n`0vkyTV&fl7{rNH{K2k2IY_8{n|nJ~UYQtFC^ zp`V(WLf-`AiG74sUeEj-b%{w$TlN+~b5bl>wWLZ$)w3qXJG(a2xyptBO}gXHXJb5n zk72bM{{u{cP=Dc)fA~in3^`!=1=$H8eA4+43GCP>K*RAPw3o<#wkC4EC zxJB^CCVPK|f&P@7-QlE@|I)6ipI6Hw+@qv*g+}4!Ti)y6DZJ9_rxyVW^Du8_3pvbd zyY?mEBwoJQ>8rY&v(&D)3|orzV(z^)?gvj2-M(ASt1KM+%4^7kLc0GDiJL&Tc4p3aq`%-=N(p!WW)sO6nErLB%U#^snDhTA3{C zg#log9ScXj1*#GMTD~WcM;jOm#>G&JEe+VSQoVW(X+cwwaHVe`@rwl~+=r74x@gaw z7ByH!nB&&x;EtmEagB;Vt2OgzZYRZ1ke!lOBa}OoG%A>9SP8?}Sa`VeMB@Ih`{5ij zeVrh!I;`=$a(S>^#b-g*i?Q4s0+kSCn)5e^x!>?VFTQ?2-Niu$=&GGpW|;+c^D2kOZe8ron>wX$6E%x2eGI6y;G>v$^GQ|z~UNA3)qO5j*q6-}^ zrL)$a75&1*%Y*U8#}Z$`u{s|}Bw}Re$?!3`z{{VZ=Ec_hro`eaOXf@jk%NFQjg*_( z`-AD&9=h^~SE)3w$V$p+z6di>*56bjxz|V+=#ZdO^x{=`TwH^IH9uY*v?~~JZL~)ZrlQM zwA6gz#;_&3q$rCVemp}g4t(}p0tbn%qApva?j30Mhbms~9UWW#6nu$aW5oix_}YyXolf7k|oXM(8C; z1??gjo(h$`y>3G?%kJFo#d#sSA*nU*$JSU|=5O`QcK$8mPFeQQ7i&|Yhwk)ydf;;Z zG-vzlb=Uot1H})g_Y|}^UMk&dHkEu^eAVhFKJ78fmC8=f8sMR(2W3Anx$FV$hE=w6Q!%sTh;GeQfLo%{x?K9(~w+5!}7 zn-%%lqK0y9fSfu|lKWEWSiE0%f6)ICm`u-j`n{;71ygR}qOw$=GXqDENG_2CTg6`r z`d7{XalDh@UvV6>F8AonS)~iNTP#!%n!dbqKS^h}?%06E(Q+|`ypcl{!KiU^Txxl_ z=Hu>!>40!V&A}1gY5_^D9|ES@9Dn@AV0MrNG7}kKo8=v(8;Tux16gO8?=M(PW75yT z@xpJlt=S#&FXLd@_Gfpn#X>FZn%PT5?vMrmWD*}k$OMmB?v;N~u9MT>JyEXXSSfD1 z<5jumFWEjmG%;3m)xPHwQ_b4EU9n?JraB6rqNPrixZ<#i6E)tb)_Bt{=P9MF ztNdY|#rn7B1FLvd$VzeW|>6&lT`f#&p&9(m{&pOP|wQU+t*SD=BZJ0lc- z8I=UVxyu7>2X=z8aZY;kD(nu1x~4NWS3eCQ7cu^m7sm)}L9KOrcQ{50JO>blgH}SJjCRIcXrD?p`Ri&C-I11oEC;^|0te9~Z%y-q-510&v<=QAL^Mnps8t zTi0%0rZ^Bfbf<98U&wu)8Yc0Oem?K<#8ICQs!Am(M?aOtGmqc`_`$8%)5Lxh`DhLC ze|>c32pKz#m{M%L=r zgl>D^WPgjxFm@pMn4~p$K(p$+PA~<3F=h2OM_p%|n8M(~T;D)D3AFBd`2@f(WYfK! zTEmdD(2M9o{3nzNdK+*boD5jweNyHP>0JeCvNzz(9TRA=s@QQJ+@g2dK%MyoP1~nC zn?1TqB-mt0B}M!}sQXd=9rKmJKySe(0-1FA@zDx?`%SG34jD2{%rYfVJ+%++&b*Ry zud0qu(^9lEH!PenVV0yi$Md*^Zk$j|JBG1kH-mlC1t2SWJukpp9j7DjN`twl+*+Td zzB?Faym)7ZA1ck`+()2l&*yKNUnG_sLL<*Q!-jybzO}a=3Ka5}AEXE4nQ+qoy8!z~ z{%p5N)*SR%s@ZfadoXhDLFOmyl8^N2;nKi25h=V_Ya)Nal1b+C@zc?bpKQ&x+#q9! 
zkmbGGl5a_rq&u~HO8qb2M70pOhHw_?a!cTwGx+SDg<4;&45uhjE_Th1MPmsz=LhQx zLJzr{Zpi~R&Lv*cpICu~Kr~!rZFsx*U3@)?PJ{QbZ?1^9c%ZY!^h3a>TbA=Q1ojQu zdi@#YtPg&Qvj#~8X14q?Vtj=k6$@xLJA7m)+P(RP>S;2Wt=evvzhq$ZWe*oiv|LN@ zl_wKaU|5cvSk#c?T)$qd`W#!3tzI~>ZhUvR0zziEIbilvGFdQF$A%gi7uC!@T$rVm zU>X^vN%J$A=`A+vY6fPsNz!P3(Aqo``&m1PLeAah+$-6^OPg^LMIce!xfPst8D~$inRp7fufjz!FE=e6bEtLEZ@I_ct=t(>Zlu%4oj+s^yes4&LIvsDAd>yham_ zp(a8N8wTJwQ>lN0()MtE07#=${PXZJv)4cP+MFP<>e5`lL6G?~Qgy<;Mi*?sk`dL# zQbg<4mO;w?lwmZ_NBeU3$W%W$8L5{fC~6&ms&ReI53+UvP&NUwL7T=u9KR%i$~%Df zQ{7$>;K&2&G@Se&2Oy{a6}!Uh1B=xDkDoP{E~0^3xt9x43!f}RlDe(63}aoISxzrM z6TYZ{U=u+{=R98ldGcXP+9Cemk^Rzxe>^5eq}KrHGk^PN5q}){{9Vaddw!rKZNF%b zR?YLCK1h1RaIn0E#7}KO7Nz%daKAzDWiAL7@G)91eoDAKND3c3B|M4#4GPToM2!~$^8NEZt>G^6(-C6(iJ@Z+&wxWXsmlQfpI}f3R2H`)hC?DNneB*Xf zq>J<@&iHlB5n;z%J58iMVfaa@Z|r#ibAvnBA82O3`-R-*5ICR#8a97lPc$pu%$NxO z(Kh|UB6mOQC)y;v(K+Dlaa}~&rqdhQ`3%|5g(-Fx&tuN50gntnvp*Ii4S8o64D5JY zPGnc3Byy!kk_q4-Rpl|ZXM~5Iy~{kOSsYpR6sX!`yj0WaEyR_mOS0oVegIeDm$q2E z_Nd#sXwU8kBXMs$Qf9mcFWf%IplVe!cX}n=Nn@pZKoQbn99Q6a|ZaG7L!8S!=mvdl;VJYt27AHYl&0ZJ3NzwV0n$k|mc3P#8*;yIMrO_`e8 z*pK*U@MEAMF@0TO?$j|}sO=D9VV$s$R$(gLNIQZKQ>}wO0}-=FGtCg`0)HArO9Zu> z=WBJwElcK#ij~X~j=DHH(&v$-W&AUsC#l!F&}@mP>bJ0+EyB&z>FbYj(U;rrKW=sE zxfqX2K&O#*nvgQi+(I4>r~}#2%653kgpl_B&7!wnhT=UHOp)g0rcu_jr|RiDo2sCD zr!Ps$9kuyTINxQU&ke9ML^K8NxS8!FhAL@?zvXGFyclBk>`p4JU&fGA5M$WZ==y4{k-6>%VXBAM>2ZEgI}!&Erb9TFvFnM{kl{=lZB@_IOkz$ zs2KxoRfv;C(<1+sx?=UR&uO``6+5+|Y6AQ-dDF>jOb+sICfs^k_N_SSkoV$1XAU0l zI)hq7SiPXV3Gc%>8<*)%E8UZy2s;jMO4ncIyoxSNI-3+2I&w7J&GHyCf|gxjJi9-2 zfiTmijJT<7<#>awxDEhQV>){d@&{WZ@uS%4b9OjnyA zydNX`D5l+oO5o?GHcg^5;ZbXkm<<`mZ8sXsjk(}G$CdeDaPJ#B&s^Vn4M&okzTxO<^DAF}~&NfmKV!mYvYirgyk&Je3OFP+=1 zWhWyIIrK&4+Uhwti9y1!F-vWXWg}~8**w{k*R_d{7&Kkn-#wA&ynM-xSt@khyLXQM z*DLKYk*yG}^v^w?Ok>B=mBRb_Uj8I1`ROzOxpPtj(_0LpOpjWZlw6EyTfU{p>f6AR zv5yFXai&i(Dp?p>21o=ja763&m;Ac&<7?y`d|r$R*T%1VicXsR;~47Z%Sbxj57xI zNI&d5F|*xp1U~5q@ttrFoAemDpIu>2KE75R7OBw5z&o+2T@8#-&eZ}kVbbF;L9|%- z`@KD?S5!d^dNeOXV>LpBR;wME(zWcvUq-ZFzkhB#fVvnoU|&)i>ZRr$)YZyiudQ&J z(mBGZLdnYYoNa+=q8&BhKRU1AcR*v*tB@6VdEwU0$2gu9FrC{gjs16}Q}^@OH3xEw z^{M*4c`Nef11*$DXlZ_u;-F;V^J^X9=b5=xOAJckmT?0FIr$A~b(qh8xWEJ1Vx zIwXIqExX;$*dc;C?qg?q-Nb4;P;dr0=wvBFL-JA}G`-ue(Qow_8wLs8kV=jB*W^v3q6PUAPe^f+}6BMCdA z0W(?O+MhLyshS%P^fM~YW##~Q{~6pKnbT(9rjij$Kg%OkWEz(u;*rLP;|3Irv5OHG z*8obTDN?DCO|M%|5%~bvjK#**%m8o-HoE*;2h|Pn|XL!mFn%NCti|KZnXJU>E z^M15To73$lYuF=$WG^?itI3W_A3`W@$q*?*4W zR3BjFa-`lvd`^${ zK}=l@Y5rAZP9G3rxX9ua*3-SHGDli;)aaRFyzo?9Xb9Amurp6kyw&wH z{y4#)DI0WidF(Uz`6G4;pop}^kTi?Vk?p5`xmSJhIx)0#MmXl$>qi#yhA&uO3IE3X_onlLk9jDLHW zDo-RDb)6?=qN@{>Sg~N>nMRINN>Yn5S=HswA`8vjHoJWL?cpznnzgCMuh>_J_a2Hq zk3txeQlj#ewnYV3l^M=Ih5I+Y!xqUPaJ+olhl=;aEe>YrnY`dz>l=flJFEPaY2fu5%7CemdsR~ieWC0P>^<@Me)l6!o`kGDk@hiF=|iysRi3hv zUEM^`uC%~X3PpzpJImeuox!P(#W!*}^u?6fOrzv^ZD`O@Vw&OM#B7%ei7vmE6~K7= zEm+@#-opCpgincuvn!dpuO(V0D^fpfr?-ewX7}M=t(5-Yd+Dav zQ8fI_@s~|nk}<&`^tLzeB;3k!3K9&16u4;|C%{6kJ3sjT(D1pOc}Ran>Y{06lN6O1 zqr};3+NDbA7ns}6_fN`5{I&}PMDX5A#7?y?$&ux~AL(;ac>38NKJ)veC8Sc~Gz)hb?%VwWRL>1KQsQMF9b zjJ=_aK?cd*!)dZySLNYaK#h+dYK<+#ZHkYN^vsF#ey<8PaEbn~TkT>Q`Eo!s=tgxY z@jMKSAyLGHjRW*PYv(L0NyYe+Sy94Wnz0+Ad)(@=sS@|1tv<$ax^-*6oR0O`-Rh9$ z$e)i0M0t=zJen4-*P;E!Q9ic)Ea+-Rc9754Ys0JW$URgPU1|_GhzD!B^T>Tlq_P3odGIYAk%oTl5SDoz(sxsJf(I z{h`}D`UQ)?=(N>xoB~FysPQVP#qyf_+7syahio#Wz8)@!bod`7)4wjk8vhh+uKiiU z<>b2UwfMIz-`IVw{u{)bN!S1s5I&-Ru&scM2F##e#{Q14%7i&F|iL|ab(@>%in3IdoT<_6%=xB?FpLB`}a z$Q~Z+UKI$lU3543qUf0{_!|`WZ%|u%PHI1Y`dRkH*O0$H=@C$V)hyzanSa4QQ6xNF zipK?O0(#T;K=S->EkOUyhNHxhDRxLIoYi7d8fh<+SW;4dAiDP}oUV3!%Ppr>2V#=^ 
zfi#^S-6g1+DnVE=)BcMe;cSwfp~7!$k^8SY4@RBgR$KO+_!Oxf#ztfN!8w|~|5Se- zC02z?^YEt$@tr6W6ZE<#KI5lWRxoM);jx#z$|Bjo$AA3^)|TSy@sL+jz4O(`_H(eL zz$cuXBPI}=-Z9jj-Rhxl`8+bL@G3=*uHeqlo%2!UD!NzQHU{tAx(1Dx>eG}kt^x9^ zKRfpP8@@)&BA-*7Sk(+kTP_Cf8D-Wt9SlHRIRW=$tFGSCO{Wj4(L18^edmUIl10Hk z)Taz)f`hBfD{usK;W(dsY)VD7^`hQchvcD_h}5-@8C7l}w;!_avQ2zQQG6caU7yUz06a;`nm^qgB0 zE0|XHOy7*N^L?Azvnc|48#PqTq|=8afr%8kdGC2wnDK?o7Y{Ty&hq(IiS@U+?~Hb?xOsl0-;+ibsinRK|#k8 z&sUxHV}C{(@2F|d8!-=l2l}>WMDQkd!n;OYd^&PD8K9U==!5Z;BA&AW;nMuI0&%En z>nQt`?wPa`t7ZZpBbX5@G_N#_Fx6WcV@p0O;4W?tmY4nla5}8psLfxIJdiW0POfJN z%25>dQz;&P^$WjdyZ_{ubv;Rjf)Dx0NKI#<-}E`k%COU?g75a@{kceU&s@b-13g9R zPi8TTZEqfe@86{GK0@w)b8{qk1D!FZG~5eg+Rr%wIoEmg1iC4|nWCYL#+l5c(Qr?T zM3~k7eI#1xk?YWi7%ljN(4O75!^^rI9{5q?i2crIqen6~o|lXcWva6}+@R%1sF8SN z1O!$2_dMpmfZ)oaNAXh`3wv~KF3eJ@hy0kZOv0|3Ed11pl)?FdDMz5{Tr86O5nioc z+toOml>R=x%ec2;+T*mcQpUB-Cn*rL`5LGYaZ?y+O-Os)1hx6J_7q7?63yvu9om*m z)!9b#l$GLajT3QEu`k=;NQ^VBZU{Sla$eg*Pkj2kxy)C1bgG_TA`G8~jbHY%`0|Ty z2MK8Tgpbsp0KKgkxB@{7=!O-I$p!#Y_#($zvCw+4o9WrAeb!_YyUTyw)<573SgXE9 z$j6H&H1M<N@MuGl09lIGbew`M!{n?Y zbNY>VRi&kg+x>nSO{|fhs#ATdws8iFn5DC9F(~H*y+vj1jMvW?wO&6<2J!4B0F9a-Y<_hMwr6aj5zTHyFc=uy4yR?hY3ZvoXwr>F*1?uMA=@LN zY#8bdi(JHyL;5Wz#Y<#@_8*~e&@!i;sQ@6r@OMsW``$zk3R#*M?Ub_#)YI(i+@vHNz z_G<}AKiWzQ$k7k*16+_M4ERQV+GKCXGar>r(M(U+KI!2!_zJ)&^{*m`(Ir93`1|ft z2uPit5xd}L%-}C_a`cZQ{PnXVNSX!iYoMKJO z%%DasxRKMaE3unq?zEH~uV3#(t5A~NT*fuWQ2Zc)bm?@mdw*Mwqybt+q{!tgH#0&D z?ff}?v6N;Bp#6%PY@RA?rVq8zGOY_JVMQ72x4bWEyo+^Nb(y3V``NYQYnR)UdR1M{ zXyTeka1L$7Dl}o`-LV<(70dRGZh`K!ti*HfdE9sMp6i{C{qFa}_P3typXB9Rc^Evm zIer%VmP(u}!N9$V^xE1{@u)Ed;cwnR7N5XK0^EyZe>HG``yTsW_r3RGXlskJ0EbIz zs^WpSURZ%j0569*J7R6q`{S^22u8E@MF}Q$`u*3=t0|IgL%dcck16x3dOt-Ziyb1Q zNOcz#)GNFxLt0hB^fco(CJ42?B|~kCi-GJ-epEcG)*4^rX1AjdnPv2~1{4w(Ir8@Z15#29VTfP&ngXkG#|u2O?<7 zS9bX_tE8U%1aBs&Dbu6341vsda*4*PSPPY^Xt7(f@lLJw)iYEM3X)Z3P_4kGe36pt zR&LO}_qYp^BxyCq%Tr}qui^Xj-QJwxR$$C281Fd^U;k2E*5aZgMxe#K6^*PP3FOmd z|Dcyd$CXMVO_tVD)8NHwP+&$(H*5L9K-#1}_OdrUBD3+v^;!fjx6{Ct(Z8fvf_n~*fF6Op^_j(^x!hS$iy?Y7F zV+fqb9BUZEN2e}!D&>fnZtm^ZgO0ui>$0(PmXZ`L%r;Y(G3pGq-_~5? z%Wp^Mp9?5pC|U<8bQQDKt!ctC^`@C#DiuVJC3CW$az;a;U-u0+mdx}sQM^XKs=OK? zw-1d4Fe`;{p+`)}t3*%+RG>fnDdTb(eeeLiT>>zG)Pr6Tn1rc|WW2a9fUM^LW_- z#ue9bPFQR9=J_SfWtFea@vA>2O1IbgBq)62U69R=g`jlUq4*QZZ)Gua?YWl_pC*i9 zhS=y_14nBcr!bkWjgwry#3i|7N~5P~y}2UWzt$>|)LM$$nqBuh{NM zLJkXiz2F(nubVEE^FthElT?8{nwHp4-X}&z1meztz}%~qlS~DYe&w1L`=^Zvvet#M z<3l$ZAUFOuZc{0PN9A|ByBPtrvnw@Wa_5oSF55=WQ zyaVpa;)lq$+NdFXf;Q4v8vup7Ny%;lI#gBl?ry19yXh-*idHSUA4hwOjtg~5e86i> z5Rp$jUJ2?w$sn$fexRUgg$o2<$IRo2Iz+{`ZO8)P=~)vlI$ zCp)MsXv?j~S1z0*j~>LLcjvPL^}36i;2XWy2pRd6;gDU<$!z>{e_g!4P*UH>%h;(G z^-C0=5v#K)wQ-9>qWIqZD{#Rj5Q9%?A*uyBYyTz-gh~Ca@VR65eQaI388&R}F^OGA z#mi@vbl+j@dGTG`7-6#)%?HbWQ+e<lk`&BJa6knb8r9p%wLu@_2UjoT9i|yKxMLL@sz_p zsqs}tqZB6Z*T(Oz151D0rTvdB!dR$5?s*TnM7n&@bGQppjV{(QIwiz@mj9?Y;T(B- zP0sr_G2u0B?oQ`w5L~!QpE1(O9F(NClJo63?#{i29F{^K%>EUOe>`ig`AscEjRh}w zCR0XL-yrveh)gAYElvmX3pH7agf42FEy|&`u{M{mIqEK!@Z1sso5x>{9y6+%2-DZTW>hELIFmprvyf+zmybUf(grg z`Sj@^=X8zKrG9v+CoHq9_fBP5O|&7%pFfQV=;u6blt1Tyg}SA zN9CsplMEKO8E!X@D6{$;OvbXq5JxrnV;nd#?LwU+ta-iev3$)!IG;L=z7>0r*zg>? 
zczux}P~Ojn=_P}){>q2ghYvAISGHI(-U4-};HLdZLQZ}jT_4iDy?!$3p5gG8yyI7Y z>lro$#7JNd+yldxi+ZY5A|BakgOtEM{ecRj(t;1zpgsL09mT=!% z(oi+uRkzskr{}k@M15KC!tr=-jr5+hdgEmoHW;vu1Y4bD-h*f50SV!=)+5qtVimHZ znnN(-{&S-p%wz!k2|k2{>7BwCKBe-#C}a6s$BXJ`jFIQ#I=GewvXpB@*7NUeORqQj z#O^OzJbtg~{?7d4k+z~*iffTc_l+&gR6U$I`D-o3hR*<4dSx$|!mFqE#<+#SotGHZ z*1JXxV{0$kOP0#MLF~}FJR%!`CAhJO?=r*3Xf5_tRF(D&i4}vS)S`%u-i(0#bkj?t z=gKQ(Ml1cSdsm|^?Xb{bQ=PDt%>JbVgA0wLK5Za&+bs}RZl>`3`l>!f*u7EFFHAJM zJX~0R=NhQ#D5Dfn6Utg!h;bfdeKfxqyY&FaHy86E`u{Eg2$G$Qsz z#ZL@$lOMRRW0FWc48xo3XE&e=`Ihnw?W>VWPv8;17=5&AQgUfcuRW_KSTX%M8Ovmup3HICj>Q;{b4jEL- zRUOhMsd~NCM6Bxbamkxe15<{bkC#id5J4PVCtd_LUU-!(Y|S%e%sSn~9t`v0Y248x zrVwt~WSO;iy7Htsl}bOWDrGr!FZ-x8Bxws3OrjGZ2ys9$cT>7iLsRDJ3)auVsJpKD zQru)xtgHT9M$R9T8Q&30Fpplf^Q}HtR~7r^_SxnBR$=e5)}}fSqYQuJwb9rKE5?>| z`vRuCc&20l5%-D^`BH7)E`)IL;C^lnV&|b!Ppts$lQ{mE48i`81s}G%B6MfIM=CkH zDFwLKRtUNwxc!UIPwI~Qnhci-M%64P;T^VPamVW4e145w-rkp*bza6YH%aHOVpj=~ zvPQ(9*MK7)gJE3DAzd+gqO^1dP&HD+w+pTeB8k;#asOKLRY0W~kqNxIU)7i-6zW;J z&rozY!RMtae;aLc_P|Z1!xRVaY*^6|E|gYPAsQIAYfX%L>^8_(rpYL#hey$u<<%S? z8gB5qo)H;l;;E$ra>{(TkT8!BIl3VoveVY&?Z#u5W5ru>gal^Q%X z>fxtZah~lo>}FoM{or^ayAkw8_!ut!c@h)^L1%@tOJjU*@^Rf#UhyR?cFp4Du-X(2 z33Dlv`@o0xmC-MdYm-Fv$!W)@zH(c=maP=g&AF*0APRX=Vv zQn~q=Ir#g(#wyNKiu00A>>9}*!*|ZoVyM}NUv{UIN=i+w$ME8;9+wCm`})RQm9IGQ ztbIPEaFi;S&Rs+(I1~`PjOl%~9t5iT-i5n9HYbW)JdTjTZOELPzruhuS4$iD{ELPPKuN zl+|X1?p(`LQL=0lyht4o16sBUDy~SMId71+oFP~6s`p}y)YW^(*!+xg{khkQd{A0A z^@P***oqQ0&G1)G#htlyk#CvVT`j%LC}Z+iM83L9EV^u5-ga*-KJz4n2Y&pyKDxd= zdq17D(Qx#rK-cy62cdPfmBi%QQ?N>oEZ#_hj^pSJzPSzQ1nO-SAZyXvPQ+i=D9bj$geLCE&ehGMaduYELZ z@0<2BU6Q1ZK;4RVVMP;pG>s&-wb_Vj7uQ>dRwJv`dlpVQygPTIam_SWQu2Toe>xe< zZj67nO^iB|p1_d(YZ`XklSh-wD*n8S;>a@wQJqrHjfqaTiP|88rK2iN^i65V*R-(%VhEB^bxs!fjQFq|{$BHAhII8q1 zXKRo8^FwKF_2+{gsAB7?sxKuWVHZmL+U86iKpT(cDsc|Ic^hM89`ZAoE(Vm#-#-mv6yxJfQV%dA)rrE-Se9A(PFMi$XlY5# za=ktJg*95m*yxd0nS0Su1C1miLKXDXxkm)e(7L*CL6%3;?Rq6%#H-h9Ry?Q$SL<+% z$}uyr#CZ>0dxVb;nGulr=+ZYDXPC#4AU~)ybq3IKFn!3q^G0tmFAjS2c_=R!5&5sB zLDxSUW%W&v4g&Igp|HFL$aDf`rv=5jL`K$eW3z9Na&&7}$|&peqp}p&-5)LLojq^= z_7UZ4hyo<@-Tum@)7u2)jQqZ-5w)A&7o+^EI1 zrK$_$46{E+{iuY|1k!eRE3rb5Wh>x`n z`D?^p(madwNYKJjc2R-kH5bBD7C6yeY42oyQk4$>EVms1IS0p#@=FT@H(MsZaMgNQ zL{>p)Q{a@tx#!fm+cZPQp0S%G%yNe~2XNf94mKz_V`_Dmk? zQ>=qq&fA-mF*KZS^BvEfDO(4>SgbH+A!$v)SSi0E#tN(_M*iOHWDp1Upc zJk7V_V?Ns@g;^oafe&AZL&2o(1kby-!jFbEC%RE&3JY|fwTM2u2NB>OZR}rf>>J4> zz7=U<*C@yf$IxiSuU522e%i?3$exYzS1pLgsMUJy#%t$ACm(N?mgqBsHy_55Dup)R z1#+H={g@PkGFON5(e-(y$sTIi1SE+;SW?CXlPWK}05E}9+sdQ6$Nsfs!V%8XD}W9V z2}-1lA2injB#v`@scxn#3m$9?Pr9M zIO(DwX#vCIlu&>2;>nly#W&X#IM_NI#|u2|GVdl$}$<^77T?*X|!b zKD*oNuK+t3-69{({%|-*RODlRmSZxT$NaKIxxc~$>PBm$wXY*LM@Ul&6z`tfg8ND^ zM9~Zg3UO~htY;1zQxbb1r^{*rv?i7CIGw#YrRaRtx)Gl$z??)4!Z#^b{C_%;EkPf8 zQ?H?l_Vt|hsyO0dFAWwGv{vS$5j{6SMM8~mlAw#l$cyF zGSZV=Y${ZBJO14tBD4AAmc7NhTVJWyVfN|eTEWLS&nA!s+K`2IQbf=BqOQh|m@i;l zokoVuK>%q_EfC(H#}i`F<+Q?REzyHpQs?LIsj4zH$IygMhwFNIDJtWE92~n?BVVri zix9(5ML*vHNE8Ku;+B%>7EtM8G59=oTznw!mfs`#hRVorlUm~!v=37nYqIpaHWeI? 
z#G1Yj+Rd|wxb4%g{IZdi>>=AS^B(20rxufwNAsb|&@aZPWEPB|J+)qJGTvd1rut=V z=*fT*ODWf=!xtFi7>FGGgp<>?`sx=#DdZ>AcGwW^2Zgn24~ zsf8)}4%Bm$>;o1BYlBVbEBm)QuaqSUqV8}#d%BJ*9=}~drPnyRH#yVj=UE=E$;6aM zo&y=7$~uh+w2p4fXVyiw_4n(L$grSBc!LHJ&mW%FP=B5tEncyt5c~i!LWhcE3iBnL)6Y7UGWriS&*YN@* zH9wl~SsYb~RU2caTTimo7VjXQO`h*uG~OF{k>jpe^bXj|Zh7JFJ(K^?g^YQ9;19QO z@U$?0J&-^^A*;^_Mb^^JUH4aTIkLdP(<`pK?R0kMYJ4-r`)h)!uP&^-_L+X7^i|Bi zyIgrc4%ue)`t55PhPD0>ES%cMz0p#t{C&_)<0z+t=aJ&R=SQr41PxhP$>iwuAd&tp ziUy)BoX7d-6LGfKN$1oF7ixzjnvgAlxGxzhlZU#UmtNPFl$;Q2p{g$PHf&`vDg9&< zr9IEFeGydQnzm<5 zUQCS_%v@fQ*IDar4CMY&>_XEysYK&bIxQi?Pzd`ymNOL5E&L8SF7KhM# zq2C~1p{?K95dKj^=no;Er>fy}Zg|tV%c&3SLpxC~N(T6|&s}0NV1jrZs1O2q%x07~ zY13*?1cu660@C&$)#z@Ve3`S6*|`brsNnk{Pg8}t6is=VT)>ke<3@sZVUk2}8W~;8 z0bI+Tk%S}lwZ($5;0E=hynwa&5Sw)RD?C>yZ}ZV6{gP}zkg?iy5PM;uQ@ON4PK$W( z(VWFAHHXvAK{Na;tRIKZ#L5=&BVvx;ZMH278Yh##wY}Nge)X_4ovuZ{yk8AMF@bc^ z$jSsAYebVrgHM_x^@22&#ZK(Z0%Lo!?`^NWEogsvo+ZS2Gbe&LIOw`%inCz=RHlUq zU7-(KHQaRby6Pi2hdU>=Z5HF-n+~eS&%b9DRo?oT7w}`qk?N-h#27Ks*rvpszSf{W zf*}IgFBq8H*ZbtW;Kp~=^bZyw+159dgkj+L+@$(YP3zw_H&*#dJ zPW$BIM5=E`*B)@V=JWysU#EemBlnJETT+2ix8>7KDGmdzzw4ZC4~%Vd zauE|~W2q9lsP1ZoOn9~nP%6PS0+;<^1pYpJtf!hJp+_&h3)Kvw8HEp%IB zBhRRTJVj?DXW^y}sa9tPv4KiL+Y3=UoU}hWDd%@dTOHP{=g>6faZ>gHd^L`cGvBUa z;F4v)MAk{)>VvHBx8&=-P^>k6(3?^@AT`2iFhQZghq>|CaMvpVSa07x2TOb4yQdkU zPU%TCRi5(bZDT!o?#Ym*!}`q2pEL-t9(g+!?TMFj8LZJ?3?8l=d!TmLUuexD>!CmQ z+qrG`lGdsgKJV*!@ya4inF0b0W9t`20{{;})ik&0D8IC zKYFI-Jp2P1<;Ifo0<-pv{NFW-MBL8J^;tv7^u=jvR#*jAUC3&L(-EiwD4sg67Ar6f zJDe~PD}JVY>JZL|*4jB}CS+7g}rjTUM*IB5;>0L+Th|`M?SFK)u7fOld{RG{n z2Jv;bpryjTw0c_o{oMLhiGdrN#(jucs}x)2gWQq&X#(zL~CoM~q$Y zG0lIqdZ%PdV|hgCm4{i}p5*M7EXH=(5YrVnp;P2G3F3#8R%H7_ohK?Nyyd z7?3FK(9z@R+>ltU^bC$ln^;yHyIvNxAas_tP@lTui157+5noWDmgK$FCga+2^(l4( zf^Xuut+?sPO3WVAKz4IBBh+1{yl{D=Es~ zg1OUGT6`{$RaHdVa=N7pEIkC0wbWuAvPN?T$uNfKgPL{elxcH}^fQ*6c$F+4gj!dz z$=1^1pc&z^mS=j7 zTwvb?9Ga9d{31LUI(*-buni8h89~4>9MbxKFZF*n@c$l*|3lxavq3)jw@p%=ng>6A zWS3sm>;8*WL;sVUl7MBQ#Mtg3K))aZ`Hx_T_vU}Xz5$A!e=>6DX9Iys2Lj0eA8^Kk zF0XpauwMm*#L4EK|MB~O?|*KmbNZ5R)NKyO)>Z5Lnst0^DEa={jtI{nR|hu6JvR2E z^rDG==CDpd@5RCh8_*x8MrW8M{Qq`9^F5@cR&9c>3w-?ssbM5Vm@Y0p1Dt{I9r!Ac zr$}R{+=G0NjlO^3zCmczWAIOnqg3LYu9zcQYl5l~_3(8u$Fc2G3A>bAtGgp%V~LF<1Fy7rt|Xpd(qa%z7v@r2uLhNjsb(@y+&732a%+`$GhfcQjLJKG z#O}}qCTk3zWf()ELafzgqZ2P|3}jlGTzbl#)cJT`YDICMZ=8Nlu^%_R&-(^ev1$1g zcCO|E;>Mrdxca}PfOJ7j2*8W|d~^#{gh*?ZiAuAQX(tX!@W~>2K$|g6Iurf~9za{S z-%VmM=G1urT?wDQdC+nJvK{5PD87nc14UQBdv{xV;8LPwKTGm_Wst7D);ao)=qa;U z(wMTkzy*y#;IdPw^(qZR;nmq-^C)Y4L8Mvq4dMjfq|d~=aR=3343j<5yA>DddQl;O zMlVCBXu*>MM2?6mc-9sQDF_t0)3$>Z#rTnfO7^}|F032dZfI}l7T^d7D|aEM5^CPs4E^{2uS3q(aqFK?yKv0`?h5DqHZ8!i>L6qyo%I* z545990=|0~#Ezd01rrDnv;mf-C4c;fQJN0kZ5K$1$kl00c$4_!`~lZ^tB@LVgVJ1?@&+m&y~n_eO;;j`dy=BpKa#b-CHygS5{6(>~P$ismz|hZu#^^h4)Jf<6GqKE^3b& z>3mu7&04}`I?8%K@O-hwYU0F>`V!O{@faa^Sox+j%2sIsvjES4{h2?_PEf)vI*_pN zIVbZM@~2WS(%q;@0CNdb>~TJo<%#MYdc#xbW>71fEF`Erd;bAue|d+Ab2oewJy>|& zp%Q4|tk;43I)4F3xq@UMvm*0i?b*H+e=|FU%E5twn7Pj+cP8=*%G4t=bBc~V#8-J4 zkLvbf+fGK*OfcUHQjSyz5VN9s{IT!Jy=m?{yGi?q{*Ch;ah5Fl^DUCkY1h52raYC{ zJzm}dt)_ZaRV7v>D%)OB>bf;5uNWU;m~5GLP;rVU>Ui<5CenGCYQoT~kE-^J<`36= zgFwT89uQ^^({TUp@v-8XYKIn5m(k|9S~)3h>{H@x<5ua%q&{nyZ<$|;61iQ6iL(F& zcQLQrSc`z^m;y=T>*!Ch!iI%Zz3kycckgV6{u9F%outCcT>^^REh8GdyvV%!Mj-!z zZBSZ#O;WGO@$)t9pJ;uSN_H@a;<&&{fyeV~i#$Y&+t7OOF+LM1R7z%m4-%nxvTYJg zPeZHbU-0n2nTKgDx3e{ z2^Y39omg|1-Nhj!F1^=cvuCz7g%WR~7<9G-7yPQG-s0;yH91?dizn_!9MB{l@*mkK zDUJ0`>k#3$P{;46MO$q(IRfAdpW;jdM&l!J2e5GjbO3NJe1^@7$tlZXMGVH>F7zUR?-a#pMs$;BUP#r+suhet@GY z^_Gzbl$?gupSi1;yJ{aojbSCArKDrbBb+TWijw=_%;>4-4pv(|z26{epsBV#bW{l9 
z<7<`9zV4svVR3GfCwf#@L(A!Jo) zaR`8U)s+8ynB)&%w1I}en>wE=ElIh4gS@$gx`dKjqagU>i%9j}&^;cO%`3mTEP?-+ zC^Xm*YNXi}KvKaM=YWwLm?NLhKKFYCk#9liKbycpp9j{}thh+3P$C7vc2f?yrhQ~0l#0( zbb}-hEPK8aA-mnM8n_)9FiYnpoF7v6OKq$Ic^E9MAPXFSG4Kpn=6^_lEA4V*0{?gxOk%^7M}O^c|^g&#IU@0nf0;FOt;ER@31i z{h6avj#jyGj<@3ayK_ser3~$M4(wIJw-cxQVqh=&C&e72KfRzjZNNMq5&J0G#4y9>ZCVyp*!0_3AyGk|R+@v1AZ$R)v$m0<{Xk+N+z7W{6rqJglF^8CmSYbG0UFTXr^AugpHi=EaDb7|%Uao^57>e{d8eGdX|xp8l!=s3pEB)+bJVW`}z= z4p$}O=HF}hU|QQ}B;AfO&90#oaiL95{iOq(Z|o**v1}k220Mkq{N>yp+%=1%KwgdA z(>Kgo84`F`%20;(t$kE-+5hQoY=Kk#+z3Vsl){@ult!(l97}7g`oc`(5&gn%R;6D_xC+ES->2XF3(8yve+^cHu_f77#Ti_uSBK0Ua(&i_ z!-r4{UyZ7tb5k?$8IJQ!9G`{T6B02a@VBwVUP8{xH6ud!HWVmZK7pl|1MkbbpS5cw zHICY)UsC$&0!-(v8FuYI4khymaFM_bQcWIuQ)bcfZ$p>+e7R8|Zq^vDO9; ztTwm%1d12{C9<*yNc^s|f1?jR1qG!~bL2OOiiF#*U7L_UM=Q~H8_7)w?J|J0&)$^4 zo8Rj@6~AitObG&ds}D%@$G2 zPV%r}ZO9owN2UkW)lq;lK_PVOZ?C@S#XssuG#qOLEvZmXpj%#=Aql^-&9W{;1#fNx zy+Q+5T;?sS%wHJb?x)1xCebJZgPFT=?aqAfaLZeZW}`a=Z4n%S&T76Zc{`h(d%G^g z12y1fCCNK0ExrU5Y)cc@yJg;$+EwhS`DW3i>3jMa4(IYeZ`f>Qin3P3X(oFZMgBL!2{IK~kZ$WKe&X>)@wgxSw zmBawWVO_xH_U=+nFTgGbEd{Z_k zS?_XVZ^7N@M_NA9$CNkZYL_*<`Xt=sMIQ#aeq3oARM}Nb1A_szAc>dgu0v${(r`^y zYVLlC?Mm9h{b2w2F&c?&;v2?Bz5^GcwnR6+82yMel$47h zrzCYEms%tiBJkdlIghtVt)m2P#=y?9M2QeW$TU!vvsyIULlYj^?a|B#9H=L$u787| zg|_mDRwR`oBE<%MuNHvH$xKMB>eM}1DMAFJ)i$*Yk=cUrNyw_5iyHZcM5jG+UBnzd z02rckt#brX>cvM(=>AN(4;5&yGStHD1uGl~BuKm2mu_O`;IjSwv0ndRqS?8c7yK0tCiwh$LN$<#K(^ZE-P)`%x|iIsU%*tC|VQ*FJIa#%BTbi1*zI z6+AC3VJ^a`=fy6&G6PDBOJIhl2r38GO5k4GYh0bICdtH16Bf;~4zN)@-sL9iv=mF4 zY!~(vkmXMra&_tE$X7}1u4|%OlylchHp(eMys5qhG5%@bLY3j9G#BKabmR@h{!DEb?m*#J`(7^sMAFd`Q$cZUWQNO-UX^5ouF6Z84s{I zmN@v_fvP+$l5#104N5#*9?2mYa=9Rkw61ar`S!X%^%o=Cr;ZgJl<+y8+NvXHFCtC* z2IKw~iN~#kZEWEVXd3{5M9fQRsIB$VSZk4_B#9-}io!UydnX!50iGlZROL%eZ0gWF z+8h|U6@*^oWLe4?NzbE0M1N}C;=Bi|n6UkRb4HUH?b-9sLIp@V_C-Zqg?loOm3)%L~*Y1ty#r zt!(=(>rMChWFsbBHt)bC4Jb3O>9jqgXq9GqO=Iz&oLpAziJi+-62Q-S#9W({h)tBvE0Xm~@*vKyyRgJ+o`Y!R~&U5CdW zqo+(kz_R!(KfJoB{=QMzinst@p66=OMO;t;FY}hyb4C7oFTd;v*{{tj z7)&f4eHF)z3R$tM#VBGZUu-ZArXjp6ly@ICf1Ihb`elBac*+M^#jUr$y+f4Q)= z`pd`_&)Yfg)`#B|40fB_{c>`CeUf+GUTK4=k3J)Cj1)<|x?gFR(td0UIJ><=E}F0L zrTX?D0bq3S@mFP_4Vr1^oiz*vB6Pu+*E}YYBW3R;0z7;q+lMh=Wf&azFe?( za8=_0g3zb^T^iM0*ZYM+mh-uKHyzR<3b`C&D9BHVNQkVBh26b7d(tOlYUp#@_T^n{ zUu<~VM5Mqc3=38AcGAvZsz+Yz}el<*;~JtqW^!-5>ta=rNC* zC#tz=PPm+Ey5}7G$vCbkcKh=F7S@ZuZBFrgG^OUGaO+owDbL7~= zo#(h9chf76w*5ovQUv)sax}C-=zp3Ou{VeA$u@9QWl%`YJd@jSKY z^0kD}{A793nUPjW#O5V7nlvNYiB&RYf0eh3US-!+-VKfv!yP4cn0 z#=P`dWcN$n4bnwgkhH%n&y1of$Bl2Oz1`MTc9rnh1HOk>+N1J!$MrNV!A}bDG#;9|-m0|vAY@C`%5gD}u5TH&`6BB} z^=yjU+tG`t_+j@dvRaqaTmY~L6g>npKaSMNW8WYp=4wn~R;~7H58}y>@<-*uK1@<` zWM{z)W-Ud=rNxqEMspuo>M@@Subf3!~x1jBEoEBZ(ait>H zX+==mV=(+LRm8f5CQ8@yfTY>_^XQ353LcGW%lfX0#N2`~kJPdm~Kl zzd|v}9IyV(d(i|@GH1gW(N_bQG>ny{4t3(B>SWtf1B{4U>Dxw^B8?5cyN3pok8`+N z2>s&3AIQlHMErzgx_`zQ0NGX0{KOk{{Pn!JTMG5cYr-b>6IV+*OX@P^8@BDSUlXn? 
z-gFc8eNAKh?;=wG*5n09Xfu!=e2i|^{~Okusg}$1iSaGD1wSX7b2yAGcLkd@jK{vJ>BR|ib_{nA49@BQl zSDJzYRolyD%qzId6rPitTcJp1Im~ zHow9JiqF*QG!otUVUxr(4zA^I5JeE)Be(5`Ki{Q!d60ls`5(V4r{K*u03P^0(C2mE-_L8wSFume7FguHS!Pq# zN9ZGK+gWy<%2sbeU(VV8H;!w*lhfzPsYdy3f(23F`G57wruXT$AmGL=O6EWBxF3Lw7gYztm*o-fW_QD`R+fb z2!8wdS78R2#xH@7IqILi$^!+je?0&H%0A%`f9HN1Or)G!GZ;;sL{g0|3}p?i=1o5v zYM(knc%lrIqYtXR%qDT4;(ugcf6W-r{2bg!81Ad%Zkw3Pr<@(lVU0oG3B3bx6H5_O zW(_+`|I<-BAhQTv2bS;aE!3s7l(_+f3QgW@Nrh0Y?1EYcpv4!YX0m@v(%X+ihYW& zg;u2Oo9~OfI4jV0ePkQv9_UQ{y{D&KwTdsty9(%h_oGvQ!gTd9Jl4Oo55Bfh#w z&JvIay?OjHGnXCGv|%EpbX_(r$`B#>y{NwE*-#IF%Te%=fOGOIFnLFF6d&oV*e|cT zKN5>9)acuHZh)D-jLYkqsW-Z=ht{_Tcd^d#5=^#D%v=7EOmE2w_Sa&sx8MzR6|m!7 zir!glT+-k0@oVrtSuBc=b(7KsG3dI4aC)D)UxE=Wvwj3di&^%_*GZGUo%ftz?yr%? z1acQR=Lc}p$vh8*`(q8F4qY%a9^eDVY*n9&THE#j7l=`Fnt;=qJYg6;JHU$C8 z@-~to+3Hp&Ekm}@%$?CQz{TK5<0RXHJNxx>N^$n)K3boQ+6>KXiU8J4@2y$?ivU|Bk-!adHWemI|;TH!pW8h2<(SI;3!eXl7)ma&~Cq`pY-_;;t)gZ=u36O!62Dey1 zq|;}G{Z&3$c@s}c&w>~@-(kY6TQS7*e4T*Fg)+i@Iq~_Wswe<5U*HtAeN22Y$!X{t~b`a zg=2u=7dKtnt{dnCwmxMfkKu3_e0OrLpy_(kpNM$X<%1--PEIarCXWkK$b#=}d^DP~Ul12;ATsGQ(+jQLB0 zMD+Ts^gcg0B(*sR$(H^Lr!si&EH7ptbV8)#roT8~GEwKo0!k`ZsuCqbyNBvof2)ao z$j6v~tcgO|7o|;F^Buw&*;yh5cv%dT@Y!Wj!31qr%Hj4&!R}09#TKSHVOa?;y}BzR z_8YFu5xmkZylG-9oW_XYw{ZD`1x5u8J;gD<9kZl!Ed*e_M#d`dCKHn!J zLhLFsJFlq=|HoyB?!T8Ic7eJ{Q=aO0N$2C@cFwxa$vlllMwGRaa{|oXe`5Oa^>(%g zs}7J^dyZnVO(%Ms^Sgb0|4VBMsO_-^mq#p3As$SFmv;Kgo;eB}DEQ!eNCYcioY7E+iQ=Y@^c0J~eZ`_cf`fjbZnSET`CN@c1EUU{oEPFUb4w2-R89 zh|LzJ9J=d=)gNZHXV~{AdDaVP9yoM6J&t$yG_Nn4%W6y(QVO}OFr zM4d6g?b=ICX`@E6W;WcglIh$Z#2?DptpI*dI??Mvow;+=T0rgzFPA8G`h!wF_@v)O z1Hf;dsff%hi{~D%Ed0KPTr~B3@4T#F<}0_>o)1;Vsoa_^@^{Gh48Q!eF`04qYR|=* zolj9UJ#t`bbp-0=8ez==CRY5pGdF#ZTIK!40(qie;<%BvJyXHE<{}+r3KSVG-O5MO zC@C@oe=vhcVc=<$$tL#}ijIyHw-m=*e0G>xd3HW&?%bbHof>Y^lP_^>`N9l?ni=2c ze5T~+2>)^lMHdS^1aDQ%0HV&4%AEux;n@o zQFb`_+%EHS!}Tul-Qql!-fQJH4yToK^<&cazwdd?PkY{wVXiUhuocrXrFHlH!&tOi zshhe8q7i)P9|VtY5IP4VmeDZ9myXRbt=EJ8+H#LSw)!_B7AY8Pcg%vzz2>_070q4; zUP}B28S)f4yT6_NJ%a{PMuz3Z=2MM^-p8+9wrOg605>;M`9r!vUrMC7fBjQ6qoHm} zA}75B9of@~ud)HWYQ??zCZ@)WFgh?Lgg4>={L0WDPe=f%%|GF<%TsD1|D^cDp_Z|7 z#U(6LqXKA(`RPd?r=gwqX6X<1%rq@bFIpNhe})!#BJ+Xw^=95Q({%+t`=8FlrJj88 z;aArj2AN%)P45h8p1Dj;aoT;e{m)bAzSj}w`qs+9Tey&a@S4rNPg4=r(;wuCu0OwS zY}uB)DqkK#U*y86=40pfxbt6nfTMrp_T#weFzjg#Qe4Gfta-a8hqI1ljGqt}>~BR- z??$HcDvn7vVNHKng#YZj`ySvNWWXH)Qlo1J{Uqw{a!qjk_NsPT<^A+%D3r1tj2MS% zW8t)?qdrdCfj##ek{pL(=)3-);nQ2LGn5yn96T)^+A7w|bvsvAnVIDh)B$rm zkLw71c#J;sF4%K0S^9UPAoA&za>S?JMef+<_`zG`yn)kq!5%;(f4S~tBtO=yo441= z^UzN z?itOznf~v*L2zU#6q5|<@Ncy57Wz`zpFS3gz85eE`W^j9f9=b|$$$%3)2{sDHbedR zfR|b^jMMj;qz`U##G%z7XCuyWJTKP$okHcl#{>8uBMa`=n%n%`dU<*@DV3RI$dQ>~ z5r0}l8KT{l@;;&Bsuw2oce6yxZ-X~&eVilk`#OGqh2Ku=;bf_Ej}RI()gZUjVkAFIzXEl*fS1wd-F^OO_8FBD3bRBG?T&)aC z6XLvf!by!gO04%)+Ie9`K4bhh$Tpl9SIg*Q+GdK>?q8OuQS8|^lSw>k_RwAYnTwv( z*%-#I<-;|Y?NYo=j!)AE!%>+v!#RJuHIi%2yj!Cf=aGB87Ztp-pBy7x4X-(hqQ~z~ zRw0ItZELT6gER;o98_4Uyp~~;Jvyj-tc7;y=-CQ{ecMS^U@eGNojnf}pAmDaRq}m| zw*7HY=!vS`;`S>YM_5>2dW*C%!(Kd(v-!w;5Rr;Xsa3#lOxi}T<%|P-5acv#_HhAN_N1Yso|Qm z@CA83A=i6WmeIxH`!l+kBkAjpZK4JPFDINoEPR=Qzf~EfN6f1|hcBudqIh`TUo2PY z@umyYg6^>6{e{fg-evPMteX1~t)GX>Id6SrORCxjOt_XpJc3j!{tZIpk18bz6Vz() z-V-hqBa5{s5e-n9k|TOFCzPVMvx;;C_W&B6XSSEWQ_8V=>dt z)F`Mhp#5C0CQX1{fCPO7#L!72J6D0F+s&- zspjmm#oD6Rt5aw=wBsE!+ZM{ko))yOcqB@~mv#DvMS%aJ!faG3{$;ig(sYd>FL!wU z!JxfuEB4C1yjMqd??|+Jz`+#%6-oAH$xGomLF>*Ct0yTJ`u!!=R)JIqJ}&ovG56MC zQN8`v_#i5s(kYUXO2-f)Fob|8ohl6i0#Y(Ciqa(@ARr*!jC6x^gLKCXAPqC*00X|; zr_VXhd7tm|{^GoUa1Gb?a$oE{_r315KCwP>aPGh=ffptXAKy_uumMjQX0dvTGJbGS 
zV^g^Cc{P<87qJwE)%ZH+9o{X+Gby@%Ym0&?C|t8ad@|XxPr@&a@fEQM{sVmTHBP|? zFL`7jksIaS6-ctS;Yd0BDpXWXuIY9xCJmpfKA|<^LbHQQwC1^smuSkV1P5C*{n@$g zb_;=bx;i>9NEA-9;TBpf7H(IX@iM?uJ>#w9M!V$m=>R|%9sW#Rak;g1KGojDZB~D9 zoGR35oEu&g?)0P8b23>O`>oHimytWVr;Z}FmHzGZ_?t~Y?5XPwDX`zZR?)Ja^O(P-_25j;J;;6{f93X=dk93C^shKvI*|pU_L>)H*7S1ggLsDd0IF2 zVhmAUA4}w6RUXR~tCnH7M>k|4XUB8hSJ~3U8$juEgAd0GM32T3fvh81+S^ufd{4*f z@WUo9#8n z$(z0tMysS;b-P1LqJmGP;yzcM3>=dwW<8SG=nKb(ilEISP`$3)pC@T%b0QL4oH$Aj z!Jcp*rnyzlSZ>D8A8x*tO))=o82Ae&wQf1kKc(ufCNg{ysjFZsfWPsz92 zfi|t0C%w1Wp0!KPzCdY5(srd&?s^RJ{>g@|1@Y6hYvtt^2EGr{}6PS8GTDf0TS;TAHvjSbnXra^I7kek_pi1NU%ognl zbB~cKb^@zV>UTdc(p;Z5b`e-Vp$`bsT+s7Q{N7~T3KkH0$2q$ELktL7^$3WW$zAJ` zmpa(?xg{mUA@MdT)_{LUgFnLB=>BfypZsEgor`k_>C)6d9_I|j6xb(Lr!k`zqC=rl zxR=Q)b+sqmEyAM+joc#p%{cWsf}K~RGd-~?l+ zNYsJtGM>2z91E3a#FXSzx}SJ=c8$KBgWO`jcX!NwTWl`Uz1xC95Q>}UkSNU+`qgmG z+`(<;UewkW@+`;DN&n?7N|le5v7b?^jStKeUQJZ~0ug(+AkB``L;ShEi~13c8txTC z#UsiaG$_c3XS0hw#i~x@X?TZ*N#G>?d=@Bxaj-i|k75xWDeY177?jtT5RhQ8j-ueq zY0Gn8EmyH%B9oGDQlO?5g9WijqV_%|H~lO@z-TlDzkJ-fB{qLTi*6Md7h)brN7TuX-bRCIEO151XuR8XDcWzQuP-bv4z?90@E@qL?a0JwV=_eF1V=P8 zdK=HYU+gcD%iX|VU`|W7m5O~Ky?dA)bjQycw+ULaJjmMZ$x4hRKpS-WYtR(5`j(8r*VG^EFa3?TLtDiw7KV4#w?KXV$ zc*Pu8x3NvArA?sFW>fg_HSA;b&9ci1#6HN5A^n7@9=T{?v2~m4LX%sQ!Y>CYd^6J8 zhdb!@9=@{YDv5R(U zlsph!4}HZC@q7kImq328rCB0GnsmW!UDV?AP{$k(HSpw=>SnS86F^g1z?c!Z^`#)?4LC^g%QoYj56n{v^N`zx8T=^=)ea> zt5;6=3v3%-73FVBoVx@bl_>i0X;+RpbCfwIy`*dRa+mGkwB;}@{+MmI_3p)dYd8XW zh*W7oF5!AhAidrRc{CJF)}X3>R%JM7Et(zNVefP$>L=G625JbiK7rP(^usn76aeGbwE_ z(&L`eZ`0SD>@#EWry~&6_In+cyrp?`t{=$?u_i0bxUJHL6VL~*EY+-3Z-6a`Vc6*r)D=Z_qWB)W>q{v66`Mwm<8s!==!vEB*>pZ52wf z-jN|^At%Csk1UNp_NTR4>CC-8Xh?WZMps*3MlDwpVOzzAN_6@|u#4Uu9 z!`|liEIi_SU3K=H2`T`FSbu&i{R1x;bJ~_1-L`bo40F+z#ywf;;Y9KTX;{UqXo3+_ z)_;65>~jCU(qVM0c#WZYkM2R;!%_28TO_1|#HF3wVd@Pe2+CD`a{8j}p^Lk3ftgZ` zz%5gu{X1bqK5Q-v!VQ&;-v|IAKnS46%{)tFdYDdbQttGlL}RbuJ09_t$s+Z9K;LqF zn)eX>o#@khwet+Nx-qsGnCKf@pki>5x#3x2a#Y6i((l<3O`p!KOWwYcPOX^k%Ntue zrovGnx`ZHXqxiwHMUE2E>4D$}j$WL__3xARUJAV=fAr(>YqYRoZVja>cyuU|2)e^` zJ!iRv*l~%nQRF!quk-5m#L?V>KUGUb`PklKABpfH;$bU6lzf7ZzA$*Q_E5LS#G!Ck zhpfv<6f@VwLXLsu_S`ssSmIFrar&TSD9JS^OHK;sXNy_O;&g}Y^X{)$mMN!32&$s3 zzN+@1e{*tVke2al?Yrwwzc0V`w$_iFs<_LXkK0Wa0oFfdVL_!e<%8>ogTeMUV|+dE z4{l~SaiwxDG?uG;wpdi=_jh<#F=#OR6@>T0csSySnapIgQOtu+v&fOZZFcgyz4iUP z^WFsmIn(>A&M~O3@urE^PY>|Ie@<1wsFpA(-+uvfypV_8=laC6#QG=d60xd#o6n*gHxGYrGD`OO#Q{=4dy}lb~xeYI<4_gyq`4%x80%OZSSeH`ENf{QR^L(z(7ITV609q#j1n;J$2fQ z$S@ML^pvjPnEAuW`0>xXwR#MQ+8G zJn3!iVG(bZ1OXZ38H7XcB+Hr@VIq+Y;*|EY!=X&Zs@c@>qT58W(l&I<$e*)T!>=-% za{m?s0bFt3N|7;om&@bZwH@%e!}1r^6O#R6tEUu*IRVx#@g~xrbUOO}vQos_HOCJKLh9M?sn#_IXv=vqlCCn|QPykqk??U2 z=MI;<4a6zNmcuRZw~Vca>V!g<0~`xpGO$(x{>U1hG^g%rE$(iLbm;W-Tm{)U1&kN6 zg>plD@eX=8Oe|8SmzP5USnAXWY+oGR17aUXWjSY{X=co7#QhV&U0TlB73#}lA_PfBO zn4(ppk?9~!d!|JmOlUz>;Kbslz$lSCsj2;w`H+>?PtFxFE06oiv3YNboWGKl?PG1f zZeDSZVHrl7x+?Os9##KfS06FSbqB9`hY7ShRD7hr(wm>5o-i0(!g~P&dFPp7o+Hml zmO|mHi3T;3xgbsWjL3~cj_2RghTU2@RIhRPKUE2dqL7RETjcawvnMJD2nSKg*HrD( zI?iUA&1_4jrL&~xIe1CGG=_>BJbB<;obW_xR85qRxt+d$eeeb-+76!SxcUneEKi8N zg?4E}VKp%^FVELQ8kOP%W_&}lqDG$BYJXFv5)27qQNsJrHC; z{(K{8^nx6R3~h6|WwC?1+40S`Q4{;@DbNq)N-OTOdf|0S(gU{PL9XAp{%APFt14+P^tZ4l8KK+Jtcc=#|;pUU7g`Xb_UCcRb z0u~{yhAb|1>(nv>&l2?5KJ!>DprZpD@z656i*y**Oh*C)<8(9A#z}JS*U)wISvAIi zAQ{zZpXUQfHzD#3dS|*7H?=dbZfs2wAnc(8f{8cZ!7(Zi>@6!?qsg{lL>ki#$PVr7T|n_2w^x z9la`M4d+zZL`uq<>Y^6*7VnX+R;w|1;t?v6t5UBnkwe)Fph6hGVyqHsIPAC~)vI0K zZA<7f3nsAEoO7ZpELK?^Nw@Ca$2&{o5h@cTyC!$t~ zh1foe>1MJcQsao~i^&WaTICH;rCxgj3j$l91)9{;Xfi2pZfz0)6(3y@BNtZ}e4YOZZ^faS!@{#Zgm zWn<^)LTVt6V#+0zqhCEZL&trCNW8sia>_Bg+XBazcnjoPTa 
zp##xmg3M>=eyT8fX1Z5;@2*p{QsFrNC7(@}EMMs_B-91NQ$xu!%BPQ6)XtiQW zV-!lR<(z%J^GRg*}gBb37`GG!2l8Rq(C)s0vsv#zK9Zd$jUGQyi&5 zQk!5uc+^)+}BH}5|vLQFhJK#qaEqV;F$*IM(#?BzVA3OzKXVfF8YjrwAIg- zine-&sZrBzv>}hT*)Wklgu-|E)5y$b^Ig=Z9a^0qSdCJ#4j;wZ zws=al4Jy68$&IgA9F$^gmmMI^F6or)TSnJcRM5p_n;A9NX+b8(y$K`tzh`FRDBjE- zDOy{Kh)kGk_Zm}G@a&84I0yRT%*vmCI-{VCc^DZO|L~Qjk4rT4 zZa!83WxW*Lco*sCu&)1OLXb0DXPmR;y{?`$*~^o|TU6STXY7t1ABqFD`98@Zkoz=t ziv&Lep%A$0wxhGoHKj=YF-?&WL=fGS07rQb02~n&RX(XsfeCbYr~%rdtoa&pRH}vM1oYaE(Zi9ZYE7|Jid0cEl(+T+tHuyy?(^K1puWXfXP1nC`A|JJ=_T^tj9TJ zupVa`c7N{sEWjps_qtDPr&w+5gP#en?M&96EpzS%A3}10AzrD7xoQHR{4E-W9YWRq z8SPx-X*}G5IoXIpKRF8>)6+jl;lE#BBtmsH0o+oYl^kk;B(czJnnGw?$Oy>o?tC2+ z*rNiIe_lDWjkjbUbc_~z6j)tYI>A(j4h~yfXA(@nQj_EDu%(}(KWP+LRg6E{XQ_Q0S6>RPwd`47M?B!I6YR??geNuAnV4ZE zFrh^l?(Q0k+4Iw|a0dohh|L40`r>;L{Z`A+7&R-y&l%AUbac`M9;ro}$$!f|{ zi@>@dWi1~~9)~~Mz*^~|?8tx2dVa#Ze>8^c@t*69!8}xf>qb8P|oyYWVmqxktVabx=O+R^4==8D3~Qx4Om2i)m0k-=^xF?R}= z8&lrWBJ#qk(n+8)XkZ<2I+7LjBAhe~6ZTebE0Y^Vi%s%$FGSu%FvO}p#<7v&hWxuc ze&r2+8b9Ves$xYuKmcGP_KDJ~&D;S5r8(d1{Ks4gA_b;R76!R42Yc}?yZrr=&iC8Q zxknrAffk3(dZtWAk{7dl`6jnhWAmSpO5Yoz^Su3j5ClET^PcNCZ>~b^t4%pKHrZ;R z^$Rk$kaMr^nN{V+!1#qfAE@-M90)2|eN%ac!}$RIrUeM_=3&ycQI6`T3`3cb!gC6d ze9F#A1BdZQo*I(EAqCqr19vNfZrXjMlY1wFd>9l*EZf~X3(1u`No(MnU8Bm6m2%C| zaAS|YsZ&1$i3DZ~9i(oR)Hl6f-qBtrquOpE23xQ-pvADUTPNq^-wCQNxSlmzMM&P7 z*)!w_1t~@M*^N-N-NqaJ$L|)piy$Gm;(@t1E(%UnDK%6~ddSMj%4?19y`HYbye!QQ zSTN$x4$Y3W)0YVnNgPfl#9y%fj9I46${v07>LKdNLt8G#Yf3J02! z!7utIZ|pr$5T_?8BYs%skcC=DSOoYBV9Z%VTPOw6P0_R=vDOVjuA;89ypJ|d@G}EP zD4bO9--_Jn8`v38E7jV*;)yhRZ3k!BZUcgw23cTXQ?;LBKs3^lGi6SV4?4_Tv_^~Ipf885y*n-ylWb@igeP`a-L0Ab$1v$RNXEXZt)URBr z$MD{e_~*u3E6V7+js`P#^}5d+$y4tuoAhUCZ18Wab?m__BuO9VbJKsvt$MQ`23E_) z$hW^q4?CU-Y);KK(eRLoy~!S`Auck%0`dYIbTreTSUdJ>ye$!3p~c)ETwY+@vt}KV z1GrOt?Kxb8>2}$luLTQR(KB7`8mv7Tpa;nfjWybS-Dcq&hw|m9PkALP7BqRsBoq0t z-gxHIy#7UZ|9*gxZqT!gI4?bOPLz=TH*F$~WwqC}q9m|3VIoZf?YCvoL(uv~r6G-8 zl4UNd_h02vl zzn{iVzLlfB6Vqkdeb>sb;v*eB1B9<<-wb-Ln4+gTG}1wvvpGYdnRYbdQ^{YdAFLGRSwp{(uqsw2$P}@XWbv1dc06ns zFW_C{h!w{fUA*%m^1F@CXEqIghGbdpm0-a4xt!9>U3>IcL$f={@`IZFtAhD6{+PsW z+#&7#$2XKUEfBz9cbOb6xKEL{MQ5Bw9*6XYK`+@r>?;iqqHcFC8}HmGdWuO~Hzr6` zkRc=MzaaBtS3H_587o3Ti!{~GPRqBo3+k^)Ky347hO9nEH9$AnLeh;C_ALZ3A`tXK zaJ&39s2*Fy&OogQFcu-rBWEh1hd}93{Lfle@`S#si8~T zwZ^(XV9oMWK6JSR_X1-9{PMYeypGXl)N;#tH8_NlDe;bJ)EO1>4{_YBjNNr zcMf=5zm?}Z6TlTqT{;)}M;dQ^HN`}m$4pMnXI>9wFv?A~rD!L2DwB<`=gNbOid9rT}uA8L+5oa{8!rSBU@G^T>Zw zC>UL_4zeVnAxN)a=V{KLS^0*J-|n&tS<59kK4T)ydwC*UaCGLU4AKbzJjO6aE8%2h z4zq(pc?fNCe~@F{ zEXVQxxNTgVEeDQkzrH-wJIM#(9MM#)$l^VmJc5&$sgF6#_TQ_nddFb&$o{I_OK&)l*oQG2%EM{E zIirhx_wPvrvd5U`PF|=bBo}-+on-;$xV^|LkZ*s*mUP{1>sy<>&vAWAr~h zgpYFu%Dt5#@bc%^_c~)Ol{xmMh_Pdn3nIr*i~WXHC$Q_jd;hn`L;hzkzW@_@vNw}S z6YVoMs>`OQ4V6Bcv=V-Z;>lvvqkw=w_dx%o5E7SB1$7mJK)VmU4zxb#if4V)tpJPhDTDs&i{>y>7hU%`93)mHxopRJk z=O^W(lnQgWGRR#MsN^7y`}%ThB@irlWs}4t6(;la#7E#>OorYpTl|l^Hx#5m5BepcVC?Gp*W8?2Jbs#`SCNwTb zj*&amZJgLi*BL1lyGvYSsWHCKf2a22@2jjGz@g$D-@7+tvqc=CHiqw_1J2Z zqod4%{?$PrJ=?R9@;Pmpjz=%4v*xo^6z=bKdI33qHt<94u!JRs?8ySFF=uC|70;Nb z&YrLKx+RAtx#_4YfVUBg(QZG$6p}Gkl`c3hj_fG6$z(X%y0S$n=@Z;GV2pcHrf6cb zu$vb3_pBK%qdDp%2I)vwx!|_6D~?()TFl(H>g){c@bcXM#Otq?snYdoy8T-A`N!j|>72i^v^5a&;C8hjl)Txm z0{rj(0=2bt$YZ~ij&(F&Lv@Ec@jH3s+t`lI`Q;2q7|=|8yl)m^Ss$d7sv$&nh+WyY zX<*>N4Mo+oFA-xD_qLyouGH3!!(-UfPR{p`=Ub{PoC)&MKh@|hi3gPI^l&%==}VT# zynR(srJ=Tng=J)dVcr{A~v3CLfod7MmWP8L>?g&)PPHsRhaC3o$sKPQLcdtAPGyB6vUJZV?rs9~h{G88)y|s5=7+jB2!)yBm!kxoB z)I;%Zb{<7~*qS}Cs!cQ6yBbY1EfMSQnWB{8qV-iZ1%YnkM4g9$dZvy@7(1moe}RIq z5AB(+k_y^RPQ?!*vddHn%hU5^>zE;8?%HMVAAPGHb9xlSs5;6a+;ga^N5!+m5izv; 
zu3vd3dB1edYI;{9nSoMd{ajHpRSK?pxzZLf;=s+n?R+6iSBM17HoYeb+K0NzNP#4rvJaj=Kp|k*UErNYMU&qv_;t_mtPzto-7U@ zR_xb)dC!jRSmAh0hx_d}xcDgl8`x|X=PlpPvB4;cL1+`TB;`U#Uqp>xrfEFlu+#Gl zv3^bX;4!Jf^=s0=6qGMhjob)e2$$)mMExK<)jh+kRJNBbSwB>hOf#xht_?(NNiy?P3(He~|5O38TtT zitI96HptgeV9E_&b|J81VM=T>wn-u%bymtm(!0ygYIZ??*4l?R4hm|29;LAz9?*2t zx3~emVJ-V5k|o2}DX*7la?-25Ih zcZe1G$=r%=@&tdf)8JsE^C+*=&=B+VI{o_fm1~9lz$)oS=7XIl8$23(A!W4qdx0gX z9YInQh8*uqoi(>Q37k^As^3Wez23Um$+vwL5?cIX%%dMxIoIxC`zE|8j+Ta&VqvcN zalsi^C?kElB{5+v=^zzOH)ig88UrRct2@IxmoLUZkO27&Ze%)^4-4r2e!>&ri-od1S{X1#Ago@7`PyJj~p_JHea1^Aq38? zD}oVVpgnEX1!*c%oCec1h&NKCXH>Y0ImPPv=iSIxA#%EUk8iyGdlSh`vHjqqNQmK> z$96vquqISJZ0_eeVX>H{;z;LVGjV$>_sn&U+Tsr)^J~$^c&T>jL}2!e?N^wOL<`za zD-AU4k7X`OluPsdjDdq5dBzW?%76ekjWtMCCYPWcIL~u1O!csJhB?p<7SGvu5D-g)puy z4Hq@_p5efchCXdyhsU{eZ_l==A-(_lZQt<3*e9WJWfCtRrGB-VdJZ5zlMZ zpCay)s1)%W?iBB2=OS2MyE_+1h(o`kE3ikWt}XcG^Uo1s8*$&$D8n^>9QpbQ%lNXA zXfmC@zLyXt0Y0 zlp%7wrUJWjvB!P8=NIU{WH0kg(9Vrl5=%E4@qFEVgmbejNwfIl%#5eEYEJasd+!Rc zS5-BlJBTAR)T$t#)r6e~-!xI-3`)Gm3TwTq(2k-)HJl||z529Vd zkWCF*1u3x2o_@f}BOOz@PK~M#OJDO*}szC}Tv~_tHdV zl*EThGExnMfXJf|dx<+buEcfeyId=B?xv!hi#Cbe1NKBB>p!{b>v?c{cbW;k3HyAd z7-SVs=Cos_N{aVAOfLF)tODPw^gDs~DeiyMk9Lcp2X)s_FTsP4KA?%ycuk-VamG1u znss$Gk+aI0bu;Y^Oe^5!FVFM==p6dJNd6c$|Lxc${=?V|iEJc5z(d}E-_VA=s7)AU z{4B(NA+|OcQQ-RW%IXRKWUMmpr&n~(>)-kk?f`lNNR|guEJvRcV1Q<5!e1bAAVAyJ zpiPM(YG)CsEUTK#NJ`gmI!U~@U}|En(CK+UYKxxIgTGOJ_%OODc?Uom?zSZ|s0f0n zp+s9F3c2#6St-hQZo46TyMR3+!4SKh#9;ak4Uzcwi{oVK5}7zLRpu&&-4uMO@1{p{A6Y9IQlGG8Y7i!S|v(+EKBIRF506NFUpV3w0m2$|)gs zQ`Mc%X1i>!K%lobWwv;5*)4vEx=?K-UNNFcvc1l2@`&c<=lXqXdG^QNS;F{Vf7q~JBbPG9MoB=s!7gI@;9!NUfexIIt^XM@iN7{9uO2augYi=`UZW^ejSs%;zP^mN5ODJlm_xSCr z(sq)M?y(|+#%Q76=(6oA<3etHne+OhR;VNiKo`aon+2wr-dlVHOz+&I?AQ1|4gUOxuCTw>TK(Vh`Nuom&epr-E8*|_ zSuwe%ZPJ5LIbv(UiA}r6z)ZQ9jy$vGRJ9UaE0PZd=4B3)2#wn*ZcZ^$?7`zAgdL1_rQ#z$$g)!Mt2z?1PO>(PV4C26#AG$Ga{*Hw!}2Bf`eBzm zGg%w##x^`7nL1$?IbsBo)g|@w7xiHn+Q%5N3y8hG1VdU5n$4-i z$G~ZjL)n2dPv|5C7?czE3l#4JXp_jN{*;&$B zK#MLy^lR?;Q_qDRzu(KGJ>C2t_s%-ToJWp-3_VSRtDra$)4O25>Qj0fDI zhSahEaaAL&`Lx_?)7H`N6cVQ)+oshrk)b2CP~=g+v{we~v^bzR3Ao#bmYH-^%$&I? 
z&>elYXfQ)5F}M11pb*VB?zbO1q5YgZ7^3M8ZLeL)$3!T6VULb-L$6a`9vD+lHF#9_OLe zvYSlG?{&VgRQXh-ZI`>0Iet{gcW83O$cps3#1(~Ipk0BwLaP6|k58}IY@D!Vzd+xk zBYuHG!V_R37ty21f8_s0V*-v;d+ z@7IyuBruV{wx-ST(9Y9+}z=fENy9X(8o-EE~(XO6_@hxJbM;$(=X7k&!X^fIi~VHTDf&k z#EXeh)76QesJWb_d3t<-+|Uwnb#sdi(FJn1hrnRTxFi1)E_gom zYuET^(9#dU;Q0Rl6rr!uNMyjkecHepVGi5BZ{r0#AeIv?(eVp}2L+r-19eu+JdQh( zn5~61Is5o76L(cddhd}+o~eFkN+ZRAH}bU&A4omp?Mg|G{iGXbB6@<2w&WaKiOjp3 zUt3#4lde(&NQp*~>ED9lf940C2F#FF$It@dsW2BPd-3$vQ*%Qf-^E29WA>k2<&il* zAC7dnWqrAH{L$}LtRn23z(D#${lu&#*VXcnfzFWkb7+3_^CC5nW6j1?~`R<`jkSkjEqHTnbJ)2<%fo+uIW-ihc6jffpUDPuK=}(8AOZ>)k8DhXQ5YRSD=s&k zi-E`Yxx@IdA}vHD0ldCxF6i6py;C0s57mfou-h+5xg;~-MmDlY7i9%IZ(8U%iCPqF z>gp&%la4Z!q0DHS;TQb^>Aw2XxMUvKFst+nR6Lz>aH-{W3|Z1OXp=`}{sP&z zoby5t1}CL|C-q0QXDpvZJAZ#2U;yp*G7`);hgppL~(i-Ss(+FRG|J{8a^>J^C+ zM*2&EZb49|V%7)KWP_%JHAt8(!zV;YyE>thvV_xO_ad-Tpi;-c0sHqD?f>BKcvOE+ z{Z6n0X3V=!uciT$K+m`Gzb5{(6j&<-+apr6X=X7DKaU`#=HH#4*YNxTf%zKvS)GT( zB^WEZ)j?Z!r3Rf>;;5Rpm6el?X8!dapJI!N*!h~RUa6YiDci_C>Zbl-Za!;&@^s4} z7Qdt<%ihYp$MA=<&B`}iQ5=V`@2HkAQ6c&b&qWT@#PFQxcGHz^1B>h&UDk6eJ!g8= zO-LA6I|C!3*nk!a74+}UGrA$%r^M%GeT-+v*%ZSlOpg6TF(WM09L08rg z$r)zj6w=E>94rv^HNvLy{NV2aD3 zw;uM75B7L)fHNCJr7zb4x0jENC^>8|0=1`q`+_o?fXH zi8M3ODyv4pLj-oJ?4K~2oD~Pm)_q*Q|7x9J$W5eMM}bvw#FREl?#@OUPKx4t!vLfW z`o@tlltnS?M#9s#R`Z*I$L!p*4Mxtp&tiFOlh|WlC^-z^xJd`%!4V2He@*NEfBnh# zkB)QtO;(OA2ZZTYxPWOUA!vW&AL1D0`YW2D6@L zEmk#AWFj`6dj}Klk=^v0aoZ|22p3}S#9pAAI$E`2!hf)#@4~Fn{md8L;PSa7Vwb+# zluUgwvoX3vZ~f%Ua7xqh#jy6)(V1*I@7k`D+zOIJ1bD3Q3nNIu#v)U4d;6<(tH~25 z>M-Xq!qn0bcKVMWU-`U%{IREUF@UGrOq4lC6sZi|i_kyMf%%%Bf(KFra)R{q?wYAP zYdbMkCwu-r;BU=HYKJD%27xt}BQ=!J2^p91S@UOde#5CuEats2W2dn}4-)mvAyVI0 z!_Gvms4RAXM^pwzI`GcIfbm4|0y!a!#8b>W)V4kxj$3yb4LtmU1SrZ6=;aG zT7R#b$T#6JVSkQ|RtF|@9pWCbBUsOxh7SSl8g`MbN%UxP`%aRM-V6?7KpmSJvjV3U z8I)BWqk`a!I~k9aGf^&SfOI7GF zzE0_|6-Zm-DCp9i3*orc0zxQk1%QjqE>dL^z`pdY_a<+DEcqnk+V{kXKNFpT!D^e>RXIb33I z-`D5#Sg!v+6+q-ss_8b4z%rm zRrclXf?xMO>|(ixlI?z1I;WBN$Z?aqWr${mwFj3EYyJHh`^J1w+iEmTjhf|U6?>Eq zc#~Se_Pg1DXp|NnF83FM1yllIsx)0#S$;|yOUcTlxV__6{Kc8^AlBv9i*@PoyL}`) zlXd=x5hiDgUeq?Og5+ z#?&8OuD3x}sPCjzC#0djVajm`SfuW%6{F2#HG>zNjjw!CL~6VeioQ5Y4nImamj9Qq z>0-Dr{h7NO`CH(eJ+nDIc4=FrLkj^@ zHV9Z^Yzr?_yy8z9yw?v-S3H;Qo6ygSMAfpyqfe5SnYEQKi?ha$Tu3#aBWm5t9eu^a zIwkA_wWgFPB6?IM!DZ|vel3Rbs81Nn3zsX3^j%;1-RQG}3niHSMVjq-$=ucY;Mm^b zXhfn+$MX#1RM9nhMIA>Ij??(YmdAtT$<%A=?hb+?5%8Qj6QQIgAQCamvTQ(r7EwfX zhe9a|?9znC<;yF5)47N14zu3(&j?HVA3W|6-YV)@mt=}DTkwH*0y=FAMlQ$fut+hI z6T}Gkf3sW8i|{|}OOw(GB87;E* z#%s(at-Bxgn+V!y|A`2W0a2g{*8hCrRXVDh72w|7zq$7XFjY6}if#o^-q(*{7r?Mc zAh>#b;WhPY=}0mWaRph0{r710e~6Wkzux@8@gvA3;S|zu0W}P~sh~!R+>JDnarvs( zzAiF7z?(`RSQ2TtI=G$xYN!AE3LW_U?>E&(2D~V;D`tB+ib?Ar$9^(--?<7=$Y7e9 z%IH>bR>T@a>U=8404Y@&*Qw{4F8wpH2|L^5@0(s~T~I70y3%*1pISM$_6;D-hra3S zSMR7#J)rcjVtX$f)R-uJgoQuC-_99sGBL4`l1ah~mOZY3*J94=A3Wp=!?wYuHvIDz z+U$&NYK*A*!b*Ga=c#)MRkv_>U->Sd8xWys5qs%qAtejat@Fs;$k2~!cZ?cc&pLf{ z*u0l`^FH3Z`AX~6{h#SKpcE^^c$oHXS{A7Wrzd8b#a02oKvE{3w82b5D5@I|qBBvn zd;$Jp@U+hMEy=25heFfF5;ZS-S)rK*udyuI0nz5VAedq9if(4d-Wpzoomg3MQV zf%gRGdE6ZPWDY#lU(B1mJMA7i_}w0u#BdOuh}`a`@#Xj zbl@|}`*!$k!)JS@FQhlMCE<dKlO!^IOiPd4AFtiQhhMHNI3zK8qAz@up=SiIzeM~;ew z+lS245gFfE&$1O;DJOpQ?83f2ktZe&6Gv0Ec4;)fcfvjfA{s8sw<`l(DmUrIORA9Z z@^Pr^S^->tkF!irH2|qv1)@ZdY@@9fn}s#M(nTq36K0`iK)=f~Vx<&t?$t4eOcowq zXxR=fo?ou#9lwF54ad8kRHwV5Z~ngsi8V!2!xawGH7a^pwsb0deotDO;J#@#nhB%r zyoH}ib0oN{Bj+?_ZUAvYf}Jni9fLr&lM|(qlj+_6_^NMY51KBwJ`_bJWYx2I(fjar ze<7Ux%d#7w%Kr%u8&9XjGX6>U7kNWe5`GK+$mmxCfbjQTlL_PdYup3K05Cy+X8_WG z3;=8om}JwUya>D7g@l`te;Ua&AzTxNEXA^pTReoY`{98mC{v_^& zVSm>4>43J+|1VdHKKxg{$@-rV6%Ie&W4sEc#LD-aMHjx1-_UQl^2B61c 
z%aIaFDCdmJqO3WeMP+~BsD-uatRs&15-H_Wuqo?J<1I(JbEb|BJo%j%w;n*M@_rC@7&wFHvbK zB1o5#V4;hkH0dZJ9i#^dgen~b1QdiQNH39&w9pYjkltH@ASIAc10?a?zj@Dh&di+o z&Nt^<=gj;wYq6F@lD(5>?>x`_T=!MF_~q+}-=`gu)?91^IS#+gWcw{FB7;6$I~Twa z^u?U@$e4YhaO7h|@wF)VC#$B)wmxIPWPwR38*d@>-*{1K`QCae#W5{_EoE3Oc}vJxi%S!)yiai}GvS;_Bh-89=HcMD{b0#T4U=1K*~`<0F(U1*JYsu}hn; zM*p0M{7&CjgTEYQVKRD_C;B5zB%j8qSGvv#rX(|}n@{NG;r0=2jP>4NmSC z-6UzBdXRZe_3?<#0N9AQ&LjQD6izOEPL+!)_sn#>@bmN=C;iG_o9O(!*~WoJV53SN zIbQgiJOBUS{Qjewp8uuo?bu&}FzD1ks5n^w3MJVQg`oE;Mj10bfN-ZSf(^>{p~#Ar zzHvBp$dV7xIB*lsx!1Mt2dHRI2VRRwYS}}^*$d< z9&n)Es|}j2P^V%c&p(Y*1vBO+!BiGjVK+A~(r*6BvO`RaiRY$)RDh`2z89vB6pqzp zBU@qgGY0K)r=Qt4PTz@xD!==NDsyVzJ1_e)P^9}9_6BC{#&M%2Dj{t{7o0U1N*T47 zj|2ImqYw_IWEz58d%}79DJ{pO*+W@lc9NUS(wXoMMGcxcHTkasA|FfLm`OM+>J^oh zl+%e}mOd}6m~oe;y5iAr$DsB#wRQKM*B!1nhZ6pWQ$m)$xhc#V|`EMladIn67 zu8DIyHPey&0##WqgDsGU>8=aK}v?v!XL%|LF>r0o9kIIROu!-GnYwzCq!uRl8o z9-l7|Eua*79hZ78IFG-4;CiVlb!<~tba@uE@VN`l6UzEaYAEQM1>!#GHz=3nMOGUj zi{yeC1)r{IpUkqt%a=llJ~qcCVTX~^iL+{q&_u#ACLf-Edzh^L3HHveP$juTm(D46 zA-=;)pd9_yhU8*{hF>S8TE)}@X(Ej^&y!4Cv$qjxR+-PfwzPb!$h!SiKxPZcu&d0` z9taxra$|H9D_VVVJHsjO+Wqpcog+V0cwJj?wWS~%vKDb5O6jI>%14sa?Ep`+u}@wx zK1)@#j|ryQQR=gtuq$8t!AjPLKNC82qncTW{o#w?m~^!aK1|e-y))XYQ0Qg}jx%2M zZIXwHF-+V%Hg+*k65ea28wq75S>>*)x9wYAG|8u<+sUD6JN1yurrQ}8+zAReuo7KIaZ<0$ zARYp2LI@4XV?e)dD0OnNE^|sgTl2EI@B1yDW;+_T7rG+vJ%R3MEn$cHzzrKg@dq&6 z%L<^?Mnj;@%YAQ-pTHQNYYJ=(y#4RheF^m!3QrHGY?MvT=)xiA~*_5Vgs1xri$}^MP@t6T*tHDkOz}`R7b8xxi zv9Av=aX62W$7LV_=t(H)l)y8LNN%|Avh(UhqfFNu*-kzVE;mXi@GM=NuiDDF9!;+Q z2%#H*vJq_OJ`yVdN6fOuCxs+k{f1__F`~)0@_?E-BVQw);5i#3K1J4F?ivV68ae-L3o7E5x`k?q~IhlGA2w5RYlOqng(Z zyp`aLAbEk<_uzP{oEq?%0xzyXg$X_}jHUUATdyt^ebuHZ;9O=={ycpUrd6UF$b6ts%AQuYtG$`s#jBYx zSIP8&QBqxo4zE%&f1*+E!ZUm|ET|zY9osD9kQR7}lC~HCkvlQh{hx zibwaP{fSXbhwjv60OsRY{!B^5&&dZr-9-I0^QY>?8(G)D6ECdU3Y4l_x-?T5Zuas`+niw^( z1uH3_-k>dJywR5_Z=@-Ah8LaF4n9{mD%9k9`#JHPV|EknGzZ@?jNGFW2vy-kRxeR+ zCM=dNnZ7gu$}we;$xd8@#*L1U#dA+(fH4W5%?c3JHzBTcpu*=*i4fHq@*N;f!_D32 z-^k`Y58W8}*`@To^RdLqg{$3Km}B4=J>W6qdNqogpw&ysgDuULA!mm%yUt!T`ukyZ z{xu{+--|SPM7Ge^3MVE+*gD7_8GtHI>^BJHQZw`-J#KzqWI%3!nh$fr{iO0bAMBT5 za06ujKNg$M)J$?}6;x5kZAXEE!TtAz6tKm;!CE zVgHDowZK(X4{y9{{IUfB6-mXl$$agBwilMsA7j%x2O^9Wp63%iRfp2rr&8L?G`(|s zuljSOZNBUH^y*<~QltM0XveCW@jVdql7f&Q-;&{pgJ{+?)zzT%YA>07-&CsP4UW7N zh-1-J$;&0ywwBWCf-jQ}u+3#hG8>7fID8*UOEySs6I0qfLlb{^z_~s!wqt#PZ6P;F z7i*bJIB4TOr?7O?l{MTE+ZQh~e>r5J?n=}hSLaxqfIz`M?4&ME8wI{A*TLYVt@Et% zFmaOI{*J4ngvr|#SE&J00^dtm zXy<1m^hfd6ibM+3o#3tbW~*}YiP#Y1sK_2pbdvr+qL^ev<2uxNdwj zA`88LT0Pw5xtjIl&e zFkD{QvA#kJSZ|WL0CoceEQ8{7c-`slo!8Ct$&VP6xFCk@VN7SX>!vq5X0jz_V-A8(K+E`Q!k|&T~2pt2sBnC)|Gw4XW5Ac<2 ziQ$XByPyRB6+QG>V|85>e9}Yv zl5pm0&j6gHZb7EE*nE0>Al&Cl0Zlj=}A(pfl8?>4@TT_SRG8mz)RAs_oB_5 zSHf+!7np7f`Qt#0Lr?*PDPb$tD}jM*K%gX6jlb`7LFQDUFFcBU^4x?$;^b>dpLm>r z{E=+4I?4>_5_n?i!WzYp*kl|ZYbOWn8PnWlhGUDqBB7lNn~7>0Ni){WK`)%v|0CSu zf2#5H_cxe!lDnaj!%zry%((f+*wb_;MX>lwv!bf}@5Af>C+y-H?$(p8z?367^VIwW z$cc61A;*^k5z^DgJ%8X1^0q1aK{=lFEH6N`B0(T@{-X0*JbkHgWat|^y3@&(2c<^> zE2NmK#WDFh50hd)$_UOiSdsg67k#^6=Mbh;01gcEnA)B`7VX&y981jGf(6btnMw}X!h7$Nf{#s1{^0xBJ8zM{N4+Sj zND#emAHNWbE8=~UzZ((U)QY-z#$FW|z>_x}bT01sxhw<09{N>dQo z4_{8%&qS@tvifn{-0(&w-ItZBb-<+v`l5yZc;#=9Aww60B>NkLH{5T$g$ghV%^zU% z4S3SGZxLt)xtg^!|F|T-M#?xOM*#OxV z^BYtYjsuf3ILN?;FvWGMFB-U|2K|{E}?S$vJ@eVhe*$|c9}-(mzBG$ z_#HsKk`FVMeRb!w*!r1M{CI6>gG59B>U@UlP_7@CMeaedDqTFocjFCUS2N_cytDai z4p0Uwy*0=A40h9l@RoGUEAyrw3U@ju=QTWu@A=?9hbKh=t`aZuIi-i2aQ;-g6%swE@zni$(9)}g25*veO#8>6 zpGd_e@Zj<`j>;w$&BGzsCR-3g6%7wx^b zBH3$FIgBTg7@X{v;&Q?g-t;bO{02#RPu;AtQ38@~Fn`ImO8Co%ImfItdsf91*OaY5 
zC++lEOzBGw3OeqnyY{GB+jOSZ2;;fA24F$|hC=++s~b>Zn973~#7*2h=`pZlnAmgC zIy3w9W{+5VLu23rAe`DTeuAne;qL)zZLB>X+*|n6F~C*IGyjTUa??+zzC;m*g@{j%<+^@?^0wx0Fv5Dm$d>@CB4W*`4*PFAmemNSrkvU1n4AzaDL9i`baB# zU;q85(dzMmO?lmD7(fssq#a99SrnamQFtrGZ^xI_juCMgwSGzJv4^3Y1A!^Tt>8x6 zJcJ6Rq@>XzM^#$Zm-@V;&d@|XB~M!Yg_`>)qeORXfauoj-_!NFd!GOf_A>jmr_9ry z?h?LqYf`4^{Y|ZmVEyVr8m^O}pIpKcO%L(8)yi4)^%kRj@70l}lHpMc`=hev%m>c4H(pAGqrJxBXBBt%pyL1e+xvgv z*uP|7%>s0G)!(4eJH!a;dCallpBi3kAK~plM{S-`gwX`{kI2iTw!9UPkfb0y-SW3SaR{jBNMVe#)9b)8_P`iR00YtW0oY z>(Elp8PhRM$E{(El^**4-@Le5C(d;z({z`}x@36u+u5dMpZQHFGMZ3bq@Sa}wEUuw zOS-nAxqGa78N=qC0D=3r4Li+)@h;+Z;H6%^k1J%+v*4cm)x8R+8$68D)cUIb2Jxn^XPUG z>O2pwHYNF3-VPsdf3w;JKKn?=+r{EA;@@HJSt}ug5Czp}wwM5ha_;h+e6IiC$hMpmTP)q9O{UHQiPNN2qo;Gr&2IS@94?oKowMp73;i(x7n zQF3?(Zr7ay{PfQrVfUZ!SNQ4&iRw7WD<})lsYbS`yZCuhidX!so@vaxC06=uNoo*L z5;k00Lk&%=K!=|&4sd(UEGOFK2IRYC>}$gpc4X~bx5SL)XWX;iL93hsfsR=e0!7l; zjZp?_jYf3!bT=#w!&7q@J<&h--n(VDOxG=q;)|N>1z4wI6nw+c`Y*3zUboR(%Gvu( z_&iFJ^d9sCK`ypxAQDoYk_M32^IicwhCYDBUVLS{sw;ByOuxd%E>2=XZeZNFP#&DF zrCx0JWmeqcciFaW-i`FRe$HH-mcprDL9rpXTbEcwOUs(C_7~tBhv0DeE+#$A28<7o zjD4%l_iAxi{zIdrH~DhHV;R%r5MX_M3k_O?3fo(X5XD1e#8_MFBleBTJ(mtSpkA;2 z{Dsw#bN9Hq9-6>;SLi~Vmi~&b-TdS4nrye-fJ$Reukp!mP^cSq)XJ&--NC560)lAl zJH|U;JN~nL``xC_Ui3Fa(n~YV?Y#h)ej992l{$>v)x&I;79fd5fA#PH`4>l2CttxF z6d zIbRe|Xv5w$+j&uzKAJCb87Fp0$e*Z#{l)1lD;F~vsLZ7-)Ay)98Uaomh9XRLgD&fA z4(#-V_shBPzWBg6y5PKnmY~}Ki>P07Pd-sU%!z5 zV_5%8W`oh``#izhmCBtzhNGi*57kkelfx_rB7@K|Ee0KT*M`zlG&QA1OYk;py3zGO z(Wt!WP`wqgUz2uVbSuf<>FrYDr!er_&YxdVl6zC0b#}Z#oFqe_3hjWhM_7`AodBdW zAnIaGO}xgGyvE`E=4S~zauzu}YEnIy0?q!e^|A<7E&(w-!i=q29xPXVDV+vLCCDk# ze2G7@*r2+!*W2j;y7}+x%9a0o%;GO3?PD0cww+%Y_+tc-&LY8RC-TDsbc=XiQwOz_sH3IpRIr{EC+2#<_vm(-xm;t+i>xLAu`MnHJlsG@10 z)E?+>eWm=j&0m38_J8Fy|Ap7o_+PfSifU_&>f=(nQ2v~$b8B+Z2FhH#Y}W2;jMqjl z(DaeiTjzKP!ciO%N>`s>mnNCa+AwY-hWPrlHTt?mMo_mqjnY~ySu(m_kZ-5y>SNDb zl(y{JGnbBrXV-M@MOpQyQ zV?Ta2c+ryJSc^rTvk#Q~QPwDf)po0KsQq|v&<%cY(t3SI?OXe~4+K>=Kq5Mp4_GUm zUx#wIopwwW^~{kz>*_L7&TnpglBcpkw2EnR=1jlm_va+3#qlsDnhyhusB@1GxxM=m zu9s_U$@?T4@<1J2=?gel@jYd8(I7(mu?(42;oT@~kt@buc84QVZt}qH);3~5oUaBy z9Y2aEs$O6gWV#qH-(qOUY51@O#0Wa^DLY6W%|lisxOKQdm8H5&pE@$Nb!H1zX!7vq z3f?LC((;v#j(53kR&L*)1pCx(%N7xlBr?4DXqzi+II4q4>5`0+`%L%x6MjI^X%et@ z`r@Jp4A2llw+P;=Te)2w=U}7#h1|o__t@at`Pfbl5Y2YMz{@p2My=h&f0B@pusisP z)P6taQH|P6@OD|;T70;L=2lf~{CUs0vm`mf{30H=X2~{~G_^PKu`WqdB+X%sR;$mz zJt2*2%Q90Hgs3LgSVj0+UV7Dw#e|?;cMGhIs>e#%HBWpI>5KaQh2=8(|$Y~;=&f=grhEN_1AmYzQs z0eR?x7(?DIf=`f)Ta0`akYOe9C?{VjbEP{b&l37v&#ETAy(rIngRa_pMDMf@`jbOC zHO)JFU3T5q#-B%3`hjFwTq_+}U7*?{@HEMG0TaRPJ(sd5%If5Vpg#M$-6LTWhuI#? 
zh;>isZTAGF!EFlc{AH2&&=bnAMqNigsk9jtM*4obh@$JK5^Y$ z?W3UqsOW<=hz>;a!{Ch`og8-l4BWR{fp?CFsFOeKwf;&C-q@WWEPD?n$3C6B8c}wfj383MuZOtMUmt=Y&RdMuK z#bWWfuUQ5^C|!qKzBYV37Z2Fz;sg2Ncd^hk{sbJKl9gE-^%+u8jo5rM#p_yoe2%5{H=lRewt7e-D}U+K#-#wBXt0C1|e45US>~) zs5aIF5UPfC6gt{pKeYAHyS>EwjQ=+JIUtyKLZ5;^-|4IgRQOUkU-npK?6bLF()riT z8)YrijK2I@$pfZq_Q3O8bbQ|-c}aeaA5(CW&4SRSCD%WYlS02zXrZWPsRRj`Efwvwy8Q3>}vgGNPSFrP+V!U3- zO3s*TdFAv4G8(dY9>CxNpbNGt=-eGaq%~ zjgPm39=0;(^FCv8@AL-VRx(Ly0oew0G}e&QmRy5Id@3i3l`47_vlC-4rDVAA+P=_O zNC&)m{@tC2NRi}t%bD5?xgQ?U!h%@Cw)U7B>^rWx$0zFV~Qs7@Dch$!3Qu3iMxhb(+1>c zDG*|c)t=CA&`+)AknA960AN&C^cz$NoZ^}Rvda)Tk_L#auN~1xgANa`gonwXDA+-(z{> z7p_=KFYWp$>Mg)TR(tYK)I8Drt#g}jZ;?ate}jSmSB5*)>wm5()uTiH`zzQ8R`Jc( z5%#2RVz)Wz)bXu#pDb`SL7J%-3uCb`=`K+6mAfl3Wop)QT;uFIXT@6#GPd3UL3gv8 zxkz$EC!9!Fpz<8G4^wqI#}2EbwWUR!{!zUy`owO3+S`-2T-OzKStHjxuIa@4`$(IcswVH@|TPQqjaD8~DO>#@6 z#oD&KitUM|{vr%Q9iFd_M`Oh~$ZW*Yw03aBOCj>q^Nj0L2Hd5a#bPx|lb=}0CnPx3rYB4u99sHLUh->fIyQ5~*UeVpZndtO zBU`MZF@y#xGu#b}-nZHpCksj_&D-oyhnF@|u-zFI%9E4tH~8GGuKxmv=Tvc0{(^2d zG6KW9`_K-HrnR5St#tJHtjUNfp4PFs6}>!D2de#wXlXp(+m>X2_zgO~L9*E&Afh(9 z7kMz@S<-`zG})^)mFl_EI#2j0UAYz5J3}+yuYIjq>o+mD%uaJ#*;57b`HgWeq`bRZ z*?&KXBxH+&u}`ZyVR?8OGNm3pK0T4XJd%)UEaJNK?7CFYWH%?Lvj5Nism}mZfB2J_mjWAbg(;y8@&Gr5HICIQF`Epz=_=(k(M?&c=@KH2g>sbBhYN zQhw76ltoT_`GowHp&;!R1*tN|8)N4$%p*<^<52C$h%9DIvS#Q*)cpv9hS9Ck7KgX$ zHz)b(W*Pt#?DkSp0nO1(RlwRqlI%(f#Ws0%&!47BwK`LwBnNkT^O@COkIt+0ubxXi zVD56So>aMxxza99??IPox(XxcfunsCBH~KPhPmX+{9c&JX6`hd+L2uQ97TP6n%lCo zcWz&<8HQUqsyMfJ^`~1TNDvUb=j61wO5BE8k+37z!R|*CpB{!&^<5 zQiOt>gS$lS_Dx1%P6N@h)bpfvBCM0!Fd}46Oy-#aU+H@n@0a{DCBB{y1SrFr!otj5 zt#@93qD@qR1Q>XHy%e=<-N&rW^r4dX+yYGxGz24zu|r29Yr**2goB4;(hwd8{`1vC z-aHNWMstTRHZ(M+$X0ZS?VouPebIp9T7PHM3Mcv#@);8NwzB{p&JZT@d7{_U0M6=_ zY1Nd)?9cQwiw$~^-VAHd2~)XWMf7;Z+uuO}X9;fD{$yftD-&t!-jIwk%b4FvkeTCt z+t@E@M!j;Ce%U1XO9>yBm@Nyf(&oV!2V^YwZxESvy#vF)r)+r@;gBCt`L%)l^EL5g zkMjJ@J7@YYB!1g@*j_}Pw^Ser<&(V!8{IyV%x}an7Bhy`JIm=JPnh}TnQT5&WA62! 
zdt2U7RK$Dr+U*bE5mV7cdtiNwypDMqBR-X8H2l;>Yk8p8>`d=<$+f48CDEa&wvm0Y zB*D+IHiFv-tp0D%DOy5)+woRh`c|BmnA$SnZ}$M#mgL~xlKS~fY`9IMo}g3*N_J9C zcyG!iLXdX1E&8x_|1K#CI}ih@M1?{x!Lh4u-rt|{q7+~V@blIh!6i>F*SuJ0Zo%yrM*j;R-a@sUSlV!RCqSjs$BLoiZu z5L^Uf94h*H^Ch@HR@&G`XLNBGJA9+Gtdzq~JsSMF@LM6B`{8Kl`I7tZr<--7e zm2jY`8?|4Iw5~m)qR2H_?9Zdg1Cdz?Lg8w zQfYxt&93QG2!!4AS#2{nMvA4RloqjiBLzjX%xvD`K=C|fY-Zp0#3QY$@;5q(bS*XW zyh^qOaYI#N3~( zLU;t9g>pAnIxFc58a^ zr^r3b1`qP;^}r$A2IkrP=b8){nM1$LkCq=?*o}lO0JrKypvqu?P9u=LD=ph2vEHdZj%FYT`57kl0E6FCiMW z*k5lkzD-|m9hRq40tu8dw9SXr8q15#7Pq zC`e{~1vWX117-yAbugFZZ;`SZ4)pU`#>hdUzdJSJ?~?iYuBY!shn&RykeeeCZm*f>8sS^}>7)z-ZeOX6 z1Y`o9%;f2oE$?=bO?go!hHXq;GP~YK%fTdj@1w6>B0rftLUX{zc9?PZ^3YuxpMNk4_TR-||7+B2=s)dRXm=`cdwn}({n6}qy}9*fg4Vl+ z5(2Mb3-@(#mNJ%2L>yJC`O9yR>->=po2SB7JBkzT*1pd#GA4Zx(*E!`_GMn+uf&&m zi{DlX&@1n;ahkZBEYU7kwWwMv`>1nI3x;|n^xh|Vgt3@eZ0F~DM&yE$KwQE>XD`wu zIc=cBk@UMr(61G2LMKa2*W|sLv7y*QljD-!IeqfkY06eZykUw#=g88lD97o^8yjm_ z{!2M&5BF1Kl80`Z_3T>}|MG*h0Awr(b~7W-LS5*6S?x_-)ed%Eni`}c9>;4-5V+zZ;~+1@WWRGe~n@nym)4@NbSv+?7WTr zp+VSt^5a4y_!=vkQM3M`i-)L8n|76kkEQlgruXOZE^@FPe6`g2cu#jNAb{f@B8+!Zj0VgT+F%#F zT0gGo^WIfh+1O5caaz+_5SX^v&A=v@cFfzzoi+hyoB9X;D4pehd7eMVhfMbp%K#wy8j=L2hFzh>{As5T?Do+e7RW`pmRazA zACSEE8#MI_i@WogSoL}x$~|iKFyle_J1h4uuQUe-X-~TpFEepG;W`}l#gBRfN)R9+ z&`S=XYQ&X@T|i9pSk&#Y48!BH9OEvzYu`%PKEI5GyabUH7dYCH6UFMJLvAwDh`7>S z92mU1V=jYlqKp;hPq#x|rmo)b)dVGm-ZT)KR#svAIsX~c4Gx<-F-nz!i+2ZJ{Y;q4 zmy44zBdO)O^S#$s=}^1x(Pzl@rpxz=V$en0X^b8M0&u~}(BWxK0wkj9H;8fa)~S!f z_@~us))HK)LVDdZOXIEc708V}TR-44M!$r0>w;Q+-3s71gAp%Mbus5M*iEvF9j^4# znXZfWhEz1Doy1VII?s=B!9-heBe?Ot3vnXtKS-8*YD%Ac(;g{FYCLs5zlg|RvP%%F zmmG83kRfWfplBMIvEa#Woa~n;3j^$M0pq5b`rHnmhnA*&bI*L4v~S)^G+XXEt0-H0 z*D8;n_YWaI->2KaTRO8Un=YK@g1R>L?F{}yf_u+2?-PS(+b?d(^pfTmIodE$9=Bl~ z;=txzFn!`ww9cJxJnsbbJgY5LuNP;1;N{_AYyde>fpOe){ltr+lF34f(1l(v? zU&TiIxQWm!dglSdbz7cgF5WKBuO(I!X4w3oKsto-IMNOk3&`CU=GjRck1?7bM{fvJ zXlM9Yebh0~Y?d7RYxEBXT*EIMy1zcXkvSdlK-CNR*@~Pl((ZOd z1XTB0dFTQg!g$2xA3SY|4n|Sg(&rtW+0S$q)_PwXrbIQi3oM!~w8`lpGUYcY#0J@OmI$N7bJtR1YSDw~ zD@r#Pw$6}*+bqOe)A{0qZIuSx1-S#NjLMaUtd(LDuNC^V^fD(0CwF>z{XKyJ|LEt? z91aljsgJaYmDJAf0>44`CdWHrl`E}~$%s2+WG6!Ce&)zBAwP~RS8UG7zhjc2Rnb$M2BCq_57H2g% z+-_Cr7Cz9i#TNMO=DgF*StpgitLZB0$wlMv3=PaettIQ0iH z21KU?A_KbX5Nl*ly(4m8$`SRAJNNh^APn@P6k|4?B>=v@zc|9le2b8i4&QS~;`hpb zT<15=EZse}FMT4}EO$#I#h?-Z(URjG>S3Wuki~V1nZXZf(esiAmTz;P_`680MW=~7ZUlb>2TXk6-Fat`B5%l0Zpq$u#4rK>4>L_|S7?a}-7TgE37CZ>g$q!g8mHb_#^oI&iHp~V z?m89E`b|7MsOjok@@43jM^^+2Ql}u5wb8KW2~6-;y{&ZjV6&9=bYlK75 zqMzm@K@(Z80tFEFkU(S)3UTW<2<;|t*D~N#rsH7C&(Sefcy!_HMxF+ zHcW<)yM~$WsA^2iEq$-snF{!@h1sFLs47%|}=@Q}?pQL|Em=?)o&I)GMVvUw*;5OkT)5r3NQ z#5)`;f_jP+lUo*p+X7F(qg#1iwicMFdqk4DezJ(VI3t-_4?ci9Gw%%S>;Hhcz=G@5 zH-4nz*=jy|$vE{Y7RO9$=A~RP>q=wQzV>x>hke3*r%Y1`T!6F}XZgiX1B_HKf|ZyM z=Y2+6`D~S#>BzjL#%D9ZJjE3H?Jx8g=>ka&t|wX52exUgm@re9<~XN(lEB^Ci1x~i z(wjEJE^i*s+)XkESW_Nf62oGllJe1!oA3Qe@;A?_09GO^lbCQ1`FBVsmR z9>0%e*S%@xAP}M^DX*kQ_j3i0S(pb|P=5fXGEgam{;}$KTc*{07xsfk!(~DS)hX!? 
zcoE6z%sKN{FMO|FN;E42DZmK%$c0RwEGfdPh>uPO{MW{-|OX*sxBWt%h9y9^xM-eC_8o3LpFAf1Dudd zR6+B+6u_#%^i*k6e5KUWex7xq5UFz&s8op_My%%a&Lt79%V~IJT*A4)Aq-nH6Pbx5 zUNj9a+R-4g)keu=%t}7{QZwxA{9tIvT-ojZs(ln#NjwkkDw=ezUAHQF`n$kK zPwtQje1rT4a0tv1!epOyic*WQL-D}7j`$$AUrn2ZXs_opn{^9JH(D#jg+HtZ!dF1E z$(?~o4)v+15{I$+=0tC{y|`UX<)#kC3RUSx1<#V?9sv%ZAth%|@~4aPQUvf*>FjI)FA;^+shj!hkp*BK;dA zJT*73Mp&+r#tZlJJZ3qm?PMEP;#5s@k=Io)*6sq$UT5I_z;9XSnn*DOy8R3S``cDG z#BFQ>>%?@bOOZ&{D2sA;#g+hc5?w@QJr79yesID;viQxynVtq3FBWT zzLDUK<+HJ*d_X$JiQlPj_|fPBuOPh~a7%_3o_?K1W5N9D_WO!#77X&m%IW4sqFXCC z)KYnLxPxDoARC2tbf~3VJAP1G_cokUZ%8BjEzJ)Wm0h%ERaMpJ>Q`ND7iwy1Re?BD zy~xNM`l1{VNFMY8Sd6+9krVz{AR(SA4d?0FB5-u;@NUmePdKhTXsB&cL(cX-Nn9>` z@~xG7i;gyZA4wbdW9~{lC3JT_1_n{$k&AcP?45G_9QjbE|Lt@8)T75HvMDVV9U$;g ziwe5hdn*V@+o)fLVkH5ddcj4G8_Q}9UxvkU-_7iNnavUE_d8V?y?4pMUJTlDNw?B@ zyRkHW_?k+;8S5*`X^u`{m;MWS( z8yQ!V3@<&d5u{O*#h(m;*^xz%#g5JFgVj2L!yDY0KT{gVjb^CJegZ~GP`^XY#DUxw z>EbQ0;5lAEkzEPMrNC9^xYbDc+d|gfy1_rXy8Z8-@4t?S z`Oo>lv^#0IZgNV1Cv_IHn|^I?J`~AWBBM=17{D4$7c=B;i*1Vh^vOCYnLK@*F#3Q- z`OLO>0&LrGZ{QeEatzrQe{~|59_E@0+nVOWmQ%*n%69Gc0XNcL&u{Uw$GLFl)3Mja z{4{bv1v)@Z0f`Vf1|y9~tl6S_-YS0CL|t;FXF#W&YO2ABzYo3}i-TMz8O2=Hg)B&5D2NZ)AT_`wo`prn(WI*_$bajrNx;&A9|mV zmyGXIMdEVg2Cnv_MhP#L@*6D7cD5}z$P7@PMXizI0AHy>yi+ObDrc#*#T@={x*@QM z00A&FKtb#VJJJygj5A&lZ*e=kb_Ag`{syf|S0Bk(1uU){zX43peuJ){;z@vfT)-5$ zhexjJD~az%4hfG!d9~e7Y$~|)$z4C~S-TvWrL5G4a!%i~VsNUN6DDk;%BJMK8pC;U zmADoC&C=Haoz51f({t0>cD|gcX9sf= zLxmjYAQ-voVZ@WpzL@UargZ$<$hQ zN};%x}#>FJe=s`1hfg9>u!=|0o4DZe~)J^$p_oe(6w z#$n%%6nw_kT;FIV*#|#~cpbLDa@u{u7Y&&05E%?enTw~0YUo8}{noV8SU9Y558jyZ z;d9NP@cd@olKXmWEbHosvBOWmeD9x6m-&0BP3}f8W+kpd__{vH!KUz2v+yAPX;VYR znMG{vd#v1bgG=YLf}{4>rpAAAv~7iQyhP$xaLQuMm*Hg_B&6Elo^V~((P@;Pl?=~D zT8UUo(m2WIidr{M!wdT4?uO4%6EMo+^0ekg%sW4MIuoq_X>%q5opnAXQsCVq6ebdMO3 z#AD{p>Ew5r@t#Lc!n@m3G*i$f5^`~i3x3dwIu&@a2WO$KOZ`6RA}*sX-=<_bl9np5 zfuMrk6_`;pcr zmw-Ul*1yhrYvtSsi+8lt%y)zt<)A#XrT~D{<|CNC&s!NT)!7B@eCAG3BTTs66R- z(e70s6pq2h=a)I^PdGBKPKkA7S-x3PBHWt_HgvEpW+5@2aVJvFGADL6h4~_iEy2x`W>7^)GDl2KR$t=42P#IN^EWV%$xo zTg6T}9re);%Pz9hc;?4+j{*nb(`ny=-FtTuejJ@bfQd2K`B+!ojczO-ogPXpqPu?o z%~VQLRk%Rb-1g*id){Z8_iOK(zu(7jKyL(&&$Cn6Cdpc(WTA!c^xb#s?7eDL8;!FD z&1_35KE=3}bz{J!o%rc1*C&zJY)453U<|MZ2}O)ppi4mLdgC~o`7Kr-t=4BAf?w{}mp1eZ zq^r<)1REzOC(YoeP7k!Zks7S3jT4|(dXDkn#oT(5qK5LrYC2KRVy))*ux=WGf|Gi5 zlOx7UzR^HkZADMkCk^8$v{>=3-ubXVPJ-C6{gGj6luu3C%kbGs$ytk;z8~LAB^Iu! 
zfgUNhbNrOiol3w$%a7T}k4K~jxMXv%RuKx6tTbD!*@;WZdc#Z2wM^!jL7?q>7FO}< zR&UEMLMq=FFZ*`G=n#G+TEf27l%g7`p_G(iPFSC|bW#8`U`f)7(+;+eSE8qebbS+L#EyfU+)xEEZzLtb=DbqW{RvyDJ5$Y!5yJkwb}*P8Mv`U17Q8rv{*df!=-d8z z{H0%ZXIP-rzcWyaI^u)k-Nj7gm=g}VCh>b=^!;Fw((C7awVJk#)$3}Td`whxKiOx4z~A0i zf;{V(ztXM#v$>wUh6x?IryVU7MyTA6*Br#7s{9|LT{lAH;iW6`alLG3?%I8L8ldV^ zeW*%kSm>NX9s7|6wA~`JL>r)x!KwhwGLh!s{1J}tYR5wB3@-$$seIAQcrI=`_NQ^6?YN?bFE!Dl-M#HXiukLt(9$dqF~Fk0^&T+wE4kHN2 zeKav3ZNr@?j!lc)yFxZ8M#zoq0pU0LaPr5(Gvr`fcS%shvQyY>TjiCN6~XstUa}nJ z1Nqdk^g4|1IcXMq72)h9JF>O0WN)VNQ0d9itFak1?m3OnMBW12vYtZyj zIVw-gW6TeJ$_WPi4a1QAa}q8;zOG&_c8^U_VpX}I1SYb!qFAn1^QpHgq&wCCV*exC z;cSI7XtDC;j=M!k#xHufcnLbdcT9&}Fb~XQ9)wElcsx_bHu`15h z>e{Jp@qXCdLoT8mHm*BZiK7TUCO(z4Zcb{=VgC{!s-5DBkP?i!B4(dDIQ^EudW~y_ zul@C{?r|h%vlyvsiP{%<>4*2(U3-K-+Cit#Y+B?0u=nQiP`+*d_%KDvlCnh%A=za~ zgwY~NB0|}wl8`N12E!;J>x59sUR0JTYlI2?>+sX~))4pOfV^<}^Fii^`dx|@Mvb`gDZaUbe*hlUvC zI*D#~ImjCI&io9&+<%g4xAQq{{?(#Ip=)Ek^GybaIWhvzy>HQ@RxWktrv;ZM*Y(_n_V|O-?#E^lhNwKUzNWn9 zMUi|))`b;Pv%ZjRd(-WeH~h`K&F_x!rKr29nm(q0yTgmf#>nN-h?L4U8m~DvurRQj z&%mujwJ>{n`of)kssUQwY#)dC%}{craPoE3nTm({+`)^|vY)kzy>nfzE96{%d&e@s zz_qK5amI#|jMw$_>^e-)r^aoW+g1sVT(5O z7MpJkh3{9>CPIfXHHgr?D6P)Qy(Hb*wc_eX;A>3Ad|Cm^P%0t9n)TL4Vu-@*3;BrPH_vNDe4yk_GUal|57suA+-%9l)TQ&Bu z2Pg=&wrIk#jO!_co-=;dW0L`>_P>L`A9uTj7ues`_pXm#OQ4rl|lXKvVmI zxliLHwUEyYQ7rYm#uy)>RML{3j68{X7ANz4EDS=<2ZcVkYk0-#^`qd03H2DJk~$}c zkAr&&R@s?P99J4HzxPV)FRPC|u5lUaz^w&Sxt>yX1iCVtzuMxUOBf7)g$y@;ZZ^|v zTwY?TASa@|W7o~-MgyJQsmXL2aT_hfZ1bWL($i5%L?oi{<`A27`y(X!G~wk?oafHU z{+!;JaPPz3B8^VZ4{OY^cg7OD!pg6cb!Z=H3f(BgnBGNcqiWk)u8!MgCmL+21*|Ey z!e>kDh{vXH>0CaScIla&NZ|7T|K}v5Ix85lA$Vnon`A=tWPzF9laY4jylQggxy=>E z-8I5pP|KbES18KFxW(@v2f>aXkYZAcZSJgwNSEx2)6wrzN)8X=zkHgUc);FpJM^IA zqt~&YZ{E65ee}I2mW`-C56z;867I$kZL4vGev7qm{;RFF$4W-}K0KM|z>Yrcc1|lO z-YZ!0?fqjLRFI+eaX^Uk(=y2~4>TyaPwOr@?auotq}#%oaY^ac|6EzS(Hwtp3NR%0 zWDzReQ6ov*w)#itJ4KPy>;{UjSS#2~Zs1Jc)8?_-WSL$v%x zphafP>ba_~uxO-K-d6r5d&T^oD3l{BYZzsRmT zW}Q}?g|-zw)q5&*j7{WN@u*zS(wTz~C$wSmfrZdQKVQJUs~MChE(U7srksqeFN;c{ zK1=Vlx*hNO09o`}R_glX%(KZ)8W!x&tv z?hUiVeGbzvL;L-Y^dt9I`(#DlpuUzVYK`DM>X6V;SZXp=BY!Uceov?qe<_AeJ}s%b zUj_Oc;$KSYYrqb-7}fCJK)RdYwL?8QIzRV1mGmBZ%ld+vPArj)qaRC{f~BqOCc(zI z>u@~%Dc>-)s;p@mk!%>xdGR7iU;EwK>J=0DpWBW%^*?nEE>)TLmqA%jv6a#k$pTVb zHc7WSkuxgfPTP!gI1;O#Ufg`uMu%&_GQ2$J&DDnna>@pbGos!he@6SFTFGL_*inIK zB7BUkicNr&JXe1drS_rYWvAzfC4tp>Pjj#5+KK`66ZBXyG11b#LKc@0JAYYP*PsUJ z$#=wQl8d7nUq)F<0{PShi`$(tJ~{`ER!4V1HDMR+(EwV=MdHm#o-b$ZT}R=pS}w=Z zZn4+&HQfC$a-;4P|6F5l13%F+*WZxBLQ?j6N$W>&c&vfU;n|%d4*G1CYIyFp^4`>V z2<>9($AL3%vj(&@AIKh1cW6UFuM=Twk+Fss9YNtomBoj@X_)NoPtVVJ9aBp8@$I-ks_9f_WDe>ises@+##oIFO_uA* zRWl|8FU%e&isT&4PcY3VA8d8_SxMKj;4o_3K%wZI<9Nz>;EwYl$$+@@78*B^Bl%Ew zkqwb0LS8y@WkxzVWxQGLn+8r~F@#6XJV?Q#x{zEp5tZQBh7EB7i5}7io4%VJ z8~m`le!=4T_@l2U-c}0xcMR<-cO?s>lt{P{l?63f+tJ*!y`6Q5l}{a=7ZiPYCgKOz zCU+4gxLY31qqsCLvpop}rpO8iS}ADBdWaYmO)elLj6DQnS3lcvDZAGpF>J_z%hXrr zzPbBBKyUiWa1M~u+y`b< zLX3s(0o_=k>YMi5JK`84jxf&YagM|3Q%Nr!2sPTpq`K^xNo8=Wq1jA8{RK zDvc7V$Pm8U-K#GgXjSpQ56}G-BK*y0vVSVHo&n_C*LzbjSIdqZ03I@BVna2Kk+gov z(%hJ=RWxi6DO&J`+j!=^>igrXRZ6w!nxn`)H1sN*Oa>Uh0jx7fNE+2Zj`2TD0sQ{I zT{<*iYe*6Wo)U(76|I@o#AN%Bm%NIu`!1PFen(;`3fuv>sUHx)@yERdF=~nMktgLK zY7M@bLZt#UI|5$wGXzp}X3 zEh;DAnK1EV)oYa?2&C`1pT+`M)uU{J(X8{+%(< z8RcWV_Cgife-G1^*l%KD@X2jgqed5`a7$e%&Mo{`NvmzCJRn{xf4z~Du$DVK6s`fK z=1-EcM-aP|!zf~7xxyDal=i$@bec!##0(BIJ{^4=60h{Y_LQ@AlN-mzy9VRY*N&jS)k3sm0*RmXO_d1?X`wR#DDB2!O4oK zOIv<3g6PG4V!98zEnl27L7v|6SmfDpL4lN)anEw5 zHxUm#!snsp#kvD04DahYJP|pdbN%UgRk;-#5nc!-g@hk-=`FcHjaAblEH(fW^}AxD zV&wkssv%*pt0nQDW_tzSNyYcPgQ_HsL*C>T8TU!wTf!9i9{|pbeJd~sg%w!NA?0EM 
z*=d8*$1XXPELmy09DXeHHl#@sit0Q7b9mfTsdp(lVD@UQWxK3!6c(P#JW<|%p0g|9kc_?oi;ry$I@n6N4eo%5fbI8m zQf?&R!yFZlc=X{^xt7NZirK%Yy>ypwbe5k^U2DYV9!X^ja8_)E0Xs5-e9^y+yO%kE z*Q6+G+R#owSNqOpPfW`pe6bmifXnZ_y5kxTXGpcxw3;anNmyz-rJDumbQMvH}hQP?Ea?(Wn}+*HcBQ zHpoRzY#fiem)!`weXi8tZi(eshlVsth!itPPEHFI)F5wA5ofwhbiMl~*C(hNt`7~saS5eS3 z)QO@B%<1;=^82)<*=0k_r{r0$lEcT_t#`A_o8MuKa=$kMYP!|NJLm)4(QojLzT-c1L(!T@B!B&mrIvi6*mLYA5p z;3dZPkzyOl>T&jO4kz8}iSOC83A1PrEQs=2;&Kt zafKmOh)6hxQY#}xu80(80{}wRr1*R#2TN54Mb~@wo>Re(YIc#+803yOIm*s(g2kK# z5Pmy=T~ioIGlXYjK`-+{UjY_ut!ifWPHIkji>>gW$i)QqHy_SD${_?Z9nm>8y(CANl;=qulEXilH@AYTNFHpYd)ACE^A>>T=D znoi;ER1}w53F*x2xmaMo=CBf~-QwaR0Gn6#_4dAc$s+zy#E4(XpOKLM0iYbXC`FlM zGv)_fDtLYyOmzI`za{ES~TGUK?A^7O5LpC+d zp9^IjTR*f5)RI!J;q}5e`wP^031p1aCZ5qlc_G}eIV zQbj3@==a*z4;A6=S01OFBCueC%8%M)%&J~33AtQSQTo=@6gfB(7yH@!g`h-L9{ah$ zlRNP5{F{cD%UuXxVBt(-e#&G;QbI%7U{Mxr^FTROE`i*1@>$roaU*dm2 z3UKvDz;i=UIOd@m5>Ymy-001$nO;`!Ji5b4#CdhtMo7*L%ANPVq=H^*g8C%3nVxXldunw%0 ziKg>P1lTB7U?w%YhQ{o|&TlsLB}k)XjeLBxgxXiuba>qmem6HN_g;&iy%xkDWnX!a zv_Vz{Rcu1N2qjg~s0#qnlQu5u4r4LF>y>pXnLlQj5F_ie9L!>Vd9wF|dn;oT`j%{j znd!$!>#^2aS&Zo&JM3#@&k3eJz59!oO7t{ON?}}J-$r6g*x$s-7>EnMYGnM5f%!KJ zFcC|_i`YP^Sz)5=z13X%856VDMxuU4g5*lz#*NWS5c zz+r-R4H;^ar7hele4?Ntuy_A3Np0D%ymq5GQGu9PL8uv0Vl=r~`^ zfxbQmm1=;-+TLr;H7BtXn(JxIq&XtK-b%F7zafLvY}m1&_>?kzlYGIXXGvSqG4k73 z?`H-PW3jh~Dht7CQYwdyxb)htp-In(>L_7CW$P$Eu9n4T>gqeQ0 z2Dj{T8-%QRzRW)zSXxCCz^icKyo^;az$Vz^FZH$hg@eCL_v(7^WUt+1aWcQjw?o`V zcX-`AeR%1$py_Ug35H5gEIj~ztXzlyAI+-6e4Qn)*^)9wtcTSRH9dFl`dFo}zqLB; zi-j&7y;kX8oV^A0G`X8h*q13aH%A>6y=Gw8e zQi5dqk!2LeM4CI-Sr~b0_C%krX~)fy_jHrUmqHVf*ZBy6Qf{f^5dRS{Hp(9G-LwuE zoxdXRNU6aHAMvBc3AL66vvpN-Z6axRU=`P z{}5TTXFg{AVpnQR?<5n0;gws4Q4#6~bQl8QV!#{Lx0c%pzuCGUb%F>0b6*yD7U45w z*0j-(9;b=?vf_tQx}rrnnqO{}LL{=CjPEvILuJf^+z`0h3dSlpv>ZON$kgy=+N+G^ zvDrK&A-b>WXDX#H`W$@sm_SZh~+#H%W303?;svAU4fN50-=!HcuAbv08S&H+|Hd($f3HqZXC-1$heY3XkoAT0I3s@`T70 zwqE=gW9@*s-jbTq=rfkx{Kpo*1eI?{%|CF*h+BMm#urg`L_J?}?)(wxFyEh7glGI) z;LLfI`zV}*p$A2G<>xSa5Vw~~WM#5Y^`kRSaSuZezJd>2vYH2ZYT?|0{!$b*YCcMf z$gN2TL?%>&gP^10Gu@U`U?sc1XuhVFU>j^G5{G(++a>QF;2v#(dq7blpY(4wn~G3V zAS{o}f?9m=hpP@r4(ENH4DlDsIx{N-o9*u;Tw&pHeeCgtQM#!Ym&Reh*f0e3haz!W z1m#*w{n4P$>`46RM9aYBtN0W(UZ2)ICwJz!fF$fYb~Q{c=s+GRD?dR&NYAqUn}H?Z zfrZjKY5dQsiYs|y)Fo;RgPxg&o5Y-b|D^EF46#*8NLZ`sF{JLj;QfphKxEk9)2Ei< zYz7QzeHY#lZMqVa-Ci6OZ&&hUUXnBBOZ-e1;}Zi}v_RKjzRlH^EYaF&yrj@;y|+S% zx>r6p?0FD>z)JVT6JA+$6XhEbFYa?`W(mLa>jWysZCV*uoN}*{qmM%WmgHoceY@A@ zVLRZ0a5dzZYiI5`H?}17A^sbn0iTK1E)g%0snJN?6enpD3x?sm05f zK7N3WvgZfnlpTl3}Y*hNZZV(?V zG;c)vik{}Y8vy_Ds}jw5P!Of{==l%hv>e)?5;+nA5^(ieU!;P#i_dNr`+-EgF@jLmr-v&6?v4Z+3fH z6)k;=Kk~a*vPN>;0&jfi%h_JFO_tSso$IdbLFv86zSl#bsSS=EV;pp8{I&ZvOAH#Z zdtP~Z3cfa!=V7nj;hspGYsCoA6_>sbZpm9cp39Si=enD(x5LIKR**-y|EB3f`@zHA z9z)*}5XJl7rRbI=oR7Ld88m9)5r!|E`=rhQ!MCF?Qr)ov(F#mmUd*%bQ$j ze`(Vs8bxQTBGwb23TH!N&Ja!FSW8C8g%$ZJCvIL?BI04^K5Qv zv&d#}7SDiaQG~Rsh&8pMTzpMTqdR37q#3ts8uc@+O3c4~w7=0M&uJu4Md<#i9Dgm% z-Hg@4vFpeg3;cI7XHr6qGBNY@q1>K#N(1w&T@UC{rIxxU1j5qx_%O%zw=up#!W)2( zxJv;SvJl>4*JpeGvR&rG=VLx{J8b#A1vawHTp|=LO>4`r555n4Nz0JZyR&<5>V>of z2e08}i(|2}AF1x6ivH4sHQbKP?>)J<3Fa1Onq7G*bx9DZ ~@4<1>`^R2=yS!sa6 zu6SuUh+@x$_|b!eXC+z(;6X^wl6j_K$;Jn3cV>-h-JlIN!r~29nujXN#;9jf7E&e< z0w@;JYk=kLrl=LP7wu6$WuFm$SK7sHSv$n->;dFv8)K)&USe_dn7KoE?qSI%G2cGQ zmEKDeiB5)ka^ia!gHhTxqq&FGZS}*D*Q1}F>OXTm_Q)N#+}4hqx!#tn5D(q?yOhzv zj|arihU_|3>*|aZkn2}Svfq>qjcbRHMIP z((^hp-nm=N5tq&+3@_lMF7OMSaC>p^J~MGRP*b4P+M6}a#s%%D6FjLHCogyRJ+r(< z@LQ!J9=l5Z?m{PSLof5ByC%V3CAfBWI(rEWA&pa*KvIzG3^YCtM+-M20#4B~Elt?Tb}bINQQ@@!!u}?CErB$NYT>hs{LiEAJsMt0UBR9qq(qh<9Si5i5$cu( 
zwf8$k-)RzY^i`E(E^Rn)a1Hysau@0hK|e@MCC7g7P>15Bz`>i4`e_ePC@BIeKAX=K1T<@|3k(fO>8UnA-?b z_$I&T4+!OXJs|AdSVm88*g-oL5f80+@a3Q!gvQ4aT#@4QjdspIml@@Li- z^1X?BRVPem&uWfG_{i1)uZVrJBFSd~zI%cR58sIlG%9Hi?4Kzqc}iCE>REJ| zLcjlf?D^V(rtgvoq}+_3LGJ=Qr05s+7_7M+y~q3f@TaFfzu#9|zHt5du=^wRMw^CE z8W+m75cAQ?wBYNk+~)$(^6)i>uP?*J#$|F|zIz@dG8&*x3p-iWU}pA*OQuL}u}PGw zOwsA+MtK;CoA$sCYO*!O)HZ+HA4;8{f3r9 zV7eWaxTo%tMv8M22d|yhQK8+tKVM`x-ZWXk2-xCK+P&hT^$GaLQdvMFXwo8jQSWrZ zSD3Qfk3faD{v89Nd)PJ^`p`pF6C za(#|KCtN}|t)4R0>xZlndQTrT{NSP$d8*`#YJeP?cT;!kyDL#7kO4q6dt`{};X-9L z?YvAE4f4Fu01rU3&xO2{GNMfp-ejagop6hv&_NFEng?KjI0su{DF)01aT*8uBRmz| zcC3#^$s(ClaL`Q+Dzzz!5-OHuk{owk9zT^bgt$GEPM`%&h*5N%DyM%y;>BlH-UulF zfJB@9fKa)_-Eeu+x;;W&YD(bCkE$aT`aB&W@ZyT z+D0}_#|vcv*)@@k#`D0VQ(3mp<+gd0v!?Z*Fo*%{p_Vt=krG2^f%upc1N^K6Qujr$710|1@ID zScYlU{q}g#!Z~YN0{n^*dFcEcZtMP#gjJ)LSjs?pi_#yE6^Xg&|8#eK$z9>!L)&j& zMEg?Q@OA9?k`W9&W9hYZy-?aTiPP5cK0Q^7{VWzf!+ZT_djI)>8uInUHHk;lHWyNGF=M|+Ts zBe`ofncBbWbUKB~X@Ts>M&yiMEGdDeHLpvn7zg`F?@Bv*oaH+BZX$Y`J^(0)t&-_? zR`ai>wm^1O*MYq#YU9KaT4kJk%=zzS9us1A=}d~UthK>9igU)s)`n7CO#s6AKm$&e8r<={(Wp9f+({ueP_Ect-gT8rLI zQGZKDJe0ZrdG^Aw_K~cFgQ)BoT6>j1;pHTWz3$KQ&5QMTjftMd5|`X!zX?x)geSJO zdr8*kJIha`nVqzk$(~f})ZUkK{|JFzk}mpvGDteX&x8%`4tr~JT=Shm3bkNy^vFc+I^X(Ro<+OP>2s}R#O_ub!=u4I5>m}^PmOy68n;9&2=eF zh;J;$hoT?qU{Zf0$;UM|utODIGl@37Bg3~Rewd4^aD?~Gg2PkGV*}XfoDZy-}=^PDKUl*_k8d=s7oH?Zp4y*THhoOv08; zLm@!&K1rG|=@KHakIaxdyBW8VVpsar?bZ2!cY#B82U84P>m@scz1MfEjDPrKv4 zA2=wf&2;fNCJNLNu>6izB-&U=U-(oGcqnciecpYa)=Q$XS{$ZtP}7&R{=MheNHk0( z!|v8e5$E0cZyN;(CmCNUZXSX1(b!wr&}}k!8byp`3MJsqwCAaDJ1DMkEuC8)jK3AS zXaCx3pBTyRxi2(MbR)(V^kEUb^2Q7@hu$hK{)IZ-rj1|Cnlaoj^K9pJdV5H0s!4w# z<@U_&C?L}W`*53MVzm2q&?(rrl=B^n{`|U16VDZz);5A?giVe!!43%GD~6EPr%UtS zm*}Z}%TzgvfF=cS`rz#sx z+1T;3lM0}d#;c4|lMP#i-hV-?ojBQ+_-Um$$E3OZqT2@H4;VXE zT)@OgF(1e>KZy9maQz|!C@(g&)js0Hj38Dwho}!kX{Y|QR31kurP}Gea<~U{TX;U> zf)VDi@6Tx%N!i#nhg&86L<9M@jru>CW4{`bUw@x~=Li*gO0}n)HOWykS5&1g zAYT4@H36w5oR^rpTpPh@CkdvUCz4Cj+_c!6h8Xtz9 z{sh5qmnVNtWrsngJm&WB@%xaj^l3W_^y0nZ!=Bc(Mhx$s&OQX&U>ii^i%MuAAkyI9 zfr|{v6yao0zBttqOC{+O@))ZU4@WBD!y<~g77MN>nR=YNlg%th7v08szrw0~u?j|9 z9&IJF^rNSr*g#HWD&E+a-D9s~{^K%LdG69rn#dw^Ozf zHoI@|Tg{!Z%WcwxHaSNP3Q^vIs>dL-|?G1@n6 zgmlf&(oy{nuQy?-#LQA*j9nY$X_qAlJ5u>{>`A$@@7m?A%p{?`ekpf8*GkY=?-o~K zj5z^2zuKZ*v$?Bt;asN^t{YBy_fxS^x=2|*?I@ET7gLEgD{&sRNx2NcO(Noz4~7!B zfSZd2df;$nR)|9A=4M)6$V(q-H!hF_Pz{*N_{Zo?{{QgA=Z%j7Iz>&vqdChQS|)P_7~SgJy=9Br zPB;z3AxhXfdJ)*j8C-u_>i^0P4udRX_t>fK*Uw$(%#D(rfXZb-q5sq(0lP_cp?w4c zC}ipPXV1xkZHhv5hv;$TR5acj)Z>AC-9A%)S}ZK8!BeR* zXi|KiA4tRV69LG9GfwmFN8YY}cClC_Z5&t_8r|Op(GAakKvr9Dao3T`G_6QDtu%$Y zt8M2G$hPtQrxE&{V>6p7=BMO4{uEu1TTH!oQSEU(s`Z~lD~y6MLVs09uQ++G{5cSW zvxzizvMu%_{&Shfgbg2Y7c)Bi? zwgcnstbRml)~^tZZL0mB@DEer5uo=R2JRdK<`c3C=qys69-vcpWhrd zv6-6RKkV(1$qOmP;cX1x&ZPGu;H|KO&ccDWuG3NYt;+OIqN$XmYYtoB1jkUBc1+iSFI>VNzeFtzn)gCuP2V!J2~fO zio~5Tu5^sK!X9i2=vb7o_@NM^E_P6CMvCM(Eg(&5d4Y@_8GN$;WQWvL)aLa*Pl8P#9#4BJ{JtCBA2v-3V&giL=I!Hw^0v--@seN=zM} zx3VU7O!~PhvIX^`f@IW;km+tXZ>jl4Z}ScIxx%pZSy11S461g`qhIKzH;?JvH8%}! 
z-m7~jmz4vRnpw3wf>rX1tlCM>9RJUold*45WmzF+>hd%C20iQS4_YZusCuK)nmPyF z@>EDA93}FuO|e-r*Y{@6+S~fMHS^fuGmwafZ|`W|o~)E1;c7Avr0Zh}UVvTWQrK_d zO|wXueaGEW0OH(goc*$A!;U~3KrJGQ2=2rE`cA85u~N?{C9-s5|H7S&!-5Pti&k~O zA+xMp#ScB*P2tZ&-@bD4F2b}T;;FTXQ;<kY!qMjxKyzgRxuQo zQ*0P>9@?Kq^}1`e$pwfpMkBb0LjAya2DdY_1C~JZa}9)g`*|0!uNh6zY~;a>id?kr$3Zvle~g$As>VL*%W{6Ht)Fzv?t+(LH9IJLACZ(-?S{&iJI?Pu1R<^0 z!MOnYE~ql;1?sp-H|4TSny>E@>G84E)&O5F=CTc#B1Ikko|J}bDx!9VOGZp zd2zrQ=z>A*Q%qs0-I3sKep|=aJFCinW&Q5l?z6%V-No;FhrN=5LiBtFOGl_;wegR> z&KdZn@A1tx@3Q2jN7N*g&=0MYi%1E+S$Z$*RJD>?GBG>Hoj51@tjhJ2zP~Nz-LYr) zi=*%N?d?zuD>ZAchbxHID)m#J(Vzj6o}FKO528X2aHt_Lg}?idfgl2m-gQYUtb zf7h^SJB;wi0k;BqMx%*BJXcJC;$&j@0K5)NxW!SRt^e`Efbo}li+{g+>w~&u=nYPA ziXPCXvN&XhPz}*#@HlVVNF}OMxdJLT^>gZ*#ip0ac0C-XLRfLLNGW81JxeoXhR z-6Dh{!n;U?90uN~Q0*<(n*4s8jp&aj^T3XvcOjD=Z@s3&11!sBgL@1dH44zvw=kr0 zN<*6vy&n+tzEb#3hu&Vhqd{$%*C%ybpQH>a(#?7N7eCE&wb33-pUUl_)q1vZJ77fz zf{9GsF|k?_jg9mXYZMO`p)iy#K4h5+1zX*Ef_cw>nh@;xcyghraDKcRe zI%#HN>smni{_&=NtJe?$f)ePAn_ECrcb9E0=`+st_Rg#ed5AxwN}xCFQU>!1{?XO{ z+V-IDqiOrLZZ2a$fziCm@|YnjV0?}!esK9Np6@uuv%H_&-LrkA{pPIw`{{-N(WJ2g z-=twSc<1Jt)by|pPM+xtZJ8EUX|VE(NHx=|e{_HS_Qqt$3Pq1w0gz%4_~W`GxoMg` zu+>FwT7YCIQ2>{B?B1(Dumh{bG;X*qCBDl*23qiHJ>=a{yX(-@XZ?jR$O`Yx7?E=~ zqwl|()w|>6uh2Cwk4?L9leURoq|$J^XI z?+?OnOM7Hw#>q1AFJdWIrU)yW-PtUx$Im4bHu6^+4A@q49Gpyz3oy-rH3df&P_<2) zi3M9f`*MBckm9P%{GR%aUX@}!H3(+uZYO4@hWUO)>Q*Nq*) z1@BGLH;B8xCpGvCv+t=VpF{d1EAM41fTRtB*;oXI8)Z=zGdX!{pWma#t9@#;jp||a z+pMUk{F3c70VKCR$=#Vxm+%>{8py2Utr-2(T>aFV)*38SX$GX(|GNZry}l&zKJ?AT z;ws1by5{D#!gkgmq1Zp<>HI$08riaoWj^rhoy=<%=Vxl!t5R)xmX8JLP2_mbMQL z96E-0+7Cz;2@gCv=Iyx@ZU^S|HDY7zV6aIX6bd{#z~9n6odX1vBHJdDIeG~AA$ftS z1N;b4fXbp{DRxk+SF^wa4D?c-JbLwKQlD631c%PC*j{xj61$p^6RopjLYe)ZA5(&F+MI1J1%0{h9f+( zgRSf<1{$b|rbtIz-k<8C1jmlcV<^xHxrQZ${?_CD6!(R3}QKzQW+{vE~_zPEvfEvX3S%+V_ zV&47S6F~Hzlyo6eQpPMW_gAS*9#2o-K2}A|bS@BEuPj;>n`%gJzY6s3mdN8|#mdh5&S zuv&nxv1O>+x4&GIceN3Dnra0!|8+jhjQO^%DK(O`t=i(%&F^M7D!l4ve;S(wss~LM znr!RRwnr3yw?Q5n2-HCUnE2}{#JD6Dt4b)YNR!$eoIm^Z7R^vW-c3$@4%iD10)Mh; z;!3DAePSw$kEAeA8c^DOlcDtuc>VT;2Epf@u7NMH1H1(=pr#KMZ(Df|(S#mcnoL?A zSC>;T$T(Yh(f<-yUE#ksE-SzJm6(ym4VF0olMTYD^ktjNShl&%y*)*vo_yDsiDTjHm@U(W7fze&-AjZ$sp^bnfvpz!Gj)0z6Y-aC}DdX z*dE*V_G8txpDk3>fEJv{yO}qdFsmtMkd^bly-t7UGwHXwXYePo5MPHJ@b&eRCj#Po zK7e$2^;Rg`9eAyOpWw%PC;&b-RhP2H2YTKvhfg<6B0*^yC}5>r|A3+H2ga@FUx73j zI}Hz8Vkb;dT>&GfEd%REVv=Yk#MXRHqQ8I(YZZjQ!qQXo;N!w0xXlT7aEKN4&HBFj z10wR1=2n8GO2k*1@Y?FrD$w%{)5c08c9q@0pIbATyOV2uOcuj?otD>ZJnSJz2ZP_k z>QG~nFu_G%n>MuiZ?b+nx=abWZ~J80Wv>+fBOmwe%& z=3dRY04tEmbgtxeFmsGa@RfWW_F9GM;M>;_ANuOnuEff{*%DWpvWhOcR7%q#T5EIE zjvla4(H>1}OnEViyNpRkiT*UjQ-CQREgxTEm z-zxY)A8qqAxCv~NsHyZn*_5ShK$0}NfVs+B0z`dq)*6LvuuQ{NfrUx`j}|4Z{w3Pq zGe1S~ZUvvyw_(B^h#tSv3Hti-{uUAO8f@a*DQh+00G?e)_oMAs5k=dhBTOpn zZa46PrrYTsaxLrXxawgN@}U$_RdaR!XcG5X z`aRlGl>dvV&9VY1ltSEkC=UIHt2EA z3CerZBGxNYterU^@>zJPgV8CMXL%Q-(RDhRuUkXb^7gx$QIOSR+U!=MaY?Ino+aWU z58LBiHXZYq?;}M%$*EM&SVe0=enxhA4~6`IjHi=6(B9jpZ0_iuj+^oaq|cb}(Z$V? 
zO!V}A)AYZ`@i2bIal*%oP$$5)j08(CZ1Z-d&9AVYzn4d$`$QU|X*mG(Yy=LvW;5WV ze>r6VGhxvR9vwHq0p>S+MJX>4wyHy$2Hu*#iT0gt0t`SQ`olNpl)<_NVq`nr)^p{r zJD77TF7op}eehVb3Fw;~S{PtAfAhKg3f3@1{Nk5X0)9!b%BMN#@PPaOOU@8e6)erj5tuXq8i7%VmRZ{favJ z4>1V#Wc1)Nt$YmJIET$UwG9{8@AWO zUr*Pf+iyu*vwHuXZ8djR@(aa}c=}@sYfkyR^O#@X{ToH$EQ%o|u#dR;V~kxyEwjS6 zEB%ulVY4-CA+^?$5|`5_#PYS#Y6P>7>)i5Em(K1{DBb;r(NJf2_l9=EP~n*p(-cA& z>PU%`{-x&|H)agV7Sz)&pr)u9OJw$zJ3~ZIoLzJ4COz^hY^Z4iEv!I0&$jL`ive}E z30XDSvSMdShVBkCIlrvY6IyH~V9ApAR@LyI%bg|V*6-jxyre98(Kt6#?&jy<4qI#l zWt-0E!SBMy{Lmk991BRqJ-DrJ?%8$PK3ur6>?_2@Gr?_*=96F0euw>C)-R70pBt?V zB&b(Zyv)pUu-{k1Z1!OGVSvP&$oC-W6DZ3eG{Uy(gQGrhhQIVJbv8Q91d{fce(=Sa zpyGXNk>&O2$%oCP(+eX^Z{jk4u4o_a2S0ZsXEYlnWDyRaHn3-si?*GeAr60AsQ-@I z`2TBKBRB+a2@2pGLFBJfV(rm$G&Gf7byGe-L!SCtO@TFu;ohj<;ojr>26coInzb&) z_Yz|Z*%sbTb(wZ*@aXR$rPTBS@1bJfC9X933yhCL?5{72Z}_}!)EIp)+bdDMSjlTnf>%@w~!)G(v|a{={B>)DGn2NaAhsCo#q!M z5Q^XZLxA~XH&U){TOkn>@=5m@;_2-6PB)TdfyY_rF<>L=_uiV@jrRS825}!f&Mr$u z0gEKfv>TTEpYrX7=Yh{&vIBUv>&C~P7~y|S9N&7z{+pyNCO-M?ZB2BCogwYE250Gf+TBdqkI#-(5An!)IC;h%OHP`Pbd|1o_hWQAqdoEbsY^Bp~sNPavv z{x3MEfPnUB5a>9|pTb=d`|Gfa-=m-8X`gQM5FUl^74OV_$mViyJPY>W7y4`(y0K{# zwmuJtlvn?AZbpX(U{+$SjIzx4a^ay(7d1}L|0)P!2DWqz00j$xel}h3r)fc|{()&Q zp4*>1<{Woj1FTAZwx7$1*NYORile%g(9Q-N+cpwEJqW*wL;7w0mfCEmCZOdg-~b`f zanF_apb7t9kx$iAC{bl0RG&J^+@GI@^ZbmbK%WLjexdyYu>wNnX6s8IQ8ziDz$gjOgVrGQ6a_uR29lHgMCG>p#2Zg!t!qzc z$EOUy;vgs?vY~_7tpfDOTBq3lhw1NGf^6!q9$gQ>+L72tbD?@9qqJ{U0}nGX^CuFl zpWm7^RDnbVmuIh)7+Y%B_@0G({_Xg__RxQ?-HrhLO;cK&Eh^uS3p2YM0yRqh!km&-g`@^frR*NoHH}v zoNwmLobz7qpZCXHCd~8X+0Wi1`Q5{!L9qUK|5=@-=1 z=09JTsw+PlAu-!ttx?dIwjmO-wa9s9YYpalbltTjuw6((`ewA7#rSFT7plzC%V$5`a-pRiRdh@FlRZ9DRZ~a1v1^ zoI*(4R$0)6{unsdB!pB9lE5tm&Ybbs?|?i7r0eB=`mU?-<68C~#zFEKME9*m^f$jl zpc`OsjTLUJ9Mh+wW!9BnAIa@WXsMq=ZTGI3@xhDXXUF-z!be0)HsuHoa(&yykVqn3 zJiubL;=V(6iq3o!fBK%4-lV)3{5kl?&5#Mi$R!LCFwz9$fltQ)>f=3MXG?R5(}dLJ z4t(Nr0__=WeW0@dw|7o8QTK(s5u^fL0hps*#l~;_TMlXjf!*9~An`HbnH2eqOIcyDmhL4}%H@mNk_+?UA zqV_cYZ_HIRHxfLPI%E1~L4{Xr<0cdzzlcu%F|RdFjM6JFl#xw9=f)R^wGd!GyQlS^ z&ho#}AQp$Q;6p6(6)30$%6%I?kO8Er`l_(FXO-+Y(!*2Z2sL7}Bj333`X!<&F|))8 zgm6LtvtI}Jmaz^tU~S}xfBN!&`U(y=U4=3AgwxB7>mvp82QOxCZGx`R+}jDAs_zh0 zHlkoS*vn8enKs0KZ}~!qy%VA9M+d9^@QasOf1Xu3$UuV0h%b!a+1L5d5hx3k=*W*3 zo{ii2?K_S3Gjzlw+DosjgA2BWv>>;!^1OE@KG;Ev&AL>z*6y7+mZ z2>)ksn$D&Ff$?T3SL+F5s`Ta_BHgI%Yn+2S$hM)y-S{rPo}KZ+b*VwA86w5%cZlt;-oSfQXqjtMZa}gum8==Y z0=%2tNHo3EFs1sK!FPx~D5lee|I|8#;UwT#=A~3Ov zM(JdJ-p+!m=j<{p$@|MYW%%3NK`+?`AQE?K=Gyi*{~8I|_3xjePrvSs`u#DEmw9R- zM)X2YGtAh+qU3YnYl^UYO{Qv%72{*c5Y0ba{^nfv6&A&dFUGD!JQ9P@j7O!V%tk%R zJW$xaI8yq8^_aIzh}BFum&(3bUO*=1%AWUTVKGnP?!$nQkSUpmQqSo^6wF5XX$Noe zEeK0!PTaeG?G4ilGHy5dVX7#|w|D7UdObr`=bst{O)_#(1bU`%A~{0jZCoB2Ohos+ zSRiy$#u+^CuYVU161H%eG4;dO!JcL-z8JN-mC9g2;d{EGmZChDm=>Z;yWie_@b%Qf z7xe`R(}3%?PJK*cUF5WI zSy52`mp+y2v(Z29`IvCXT7z#C#e!3==$22#72Ni@>7jnsTYWRuhwEfOjzKv|?PE=x z&o$In4(zGCdAXZKfscYlq75I3;xi-aNLn2sJ?~Oyw%$jlYj3CDDDhsW8v6oz=_2SQ z)zl|ZkHRQ=ZL|xYj{n;6F*i9=syLGS zvT7^uT|-iu8tk5jSS3qw+x`jqhVh`=N3T_WE!wAag}p{?CphmhI@{ZRQMf;i%D|(g zQch~pjr*OSf2(IJ^*TlYZREMCch`ZMIPtA*HzRXwX7gtwYt@^< z9Q8UMLLZYo_(6xzT-I$JD9|l=W9yKdMflOrsn^#?UZnIqDSni6d z`(x5=2I|`8vKU$`yR%oQ49+c9-jPe=cE2M8ILNzC8n60)<2O+AXuJL?RqA=#NjIU# z?{m(cu>;jE@-WCg850{5C(;4UFJV=iy;xBFDkjFsv2vO2>z%mu7Kzk(!#LrhSB%&F zA@dZsj$Pt#Qd>N4IzQqRU381mfXT@A;0MXZVHNUDHO|mWiJA}YO%Drxit(W3@DTS@ z=^fP4zk2?hk-`gK4%kyRc~*8=IeL;ue$8Q;Uq|TPAQP@Ux$|h?Z1A~)0ot(Tus+Er zk%UvOj9?F58&4z$yR2H3{o4%T+a>Q@!tm1gt3&f8`3~hz9Gp-kC1S1lPmurSYQ65o z_N{%z*>BFDa2^@EH*0QgE-SY`S-| z1w*6GwD>ssXftD(5YPOVRL-mPbT1R)m0ppAfbG>LPx*NeRb@6Jx(hp=vVZUM&V+Vs 
ziLR!&GZm{nQ`s;3=4FTgdC)gj4RBJSBm$7cl}*%A@4xVR!E>lpXf$y<8I&UY`te^t zzu&J_-Fu^-GiY#V7iNiA3}V2%GaoiNN7WyQT2 z0|U^{r8#v``}510${(TM-(#=`YsQ=|mmU=DoCNl$gaHadj`9u2{0o z`yo;H^${x8P2<-6Zj#S-AfLarbJxdTAfLaz+}`*6C1Ym-GImICqG;RG@4R8?RW?(b zKuR{-2RD~ICsLrN{)JQgtt9|k`TJ)Ok_Vv7-}_^C-+;irRUAKvoSUcso0Qx1OWAEY z8j_>UWZL>P`Z*>Y7mmcia;HKrv-_qX8l?J&MWC?y?1XtS!e_1e{!>}hh3#HMlMrz_ z9r2$V<_`qwZ)UvPqS+6DX=aMvcKH$I)7O(9-lottl^V|~5d))^(49SRn%jB8sleIA>0sN(uN?U`IUvbUx(dBf z>r=H=R!ljJek8*la!PUcxT`XYuPCdFzs|gLCY`@FDh?#HL!52_TLh(NbP6fD+M4!` z?TgE|Jt#91jY4G)ejufhhmq*9U^Qyj*(tv+W(dRSnxzu0=+4=O$a+hc-eV+=;^Qlr zF|icYLyYC(a;7xN4ySy)P4jsiNRAx#^mrk%nYn1so|tUE8RT+MZJpkrRbx{_S)*-B zC(Ltp^o)xc!>Jn{c+0s7hSK=AGkOx+Zx`2ihj7=3M|;cPx?B?Q&t@`<7Tf+}=sdJk zz=n|%=66FFIXpumMtE^Z^WM0(msY6LEawV0-mI*|ugf(p$}!B8)8Ir?aE$wdtGE2C zh1?Es3!DhY&ep$_^kD$`g5Rwryve5_Me9kq!FgkAKb%HW`WiurV0KKks@?u(1bYj= zoG3xwnfQ5G-@E<}`C+Z9i1n`-s-dnZ|*%sI?bp0VS*e74H; ziB+P|_T5gj2!3E%yzg{+f$kfNR2|wiS*;}vka}~=nqO-=NMOq`j@lgwCjT(m7gn!RLTAH zpsG3Rw89x*j_;79KF$mci0Q~94b?_67@7qy{vo9TMRpRxU>&Uj8HOBIp~$)T`+miN z12=~$Oei6CzcluFpZYYbE>uo&s8<%64T26PEqzPc2*%uKU8fQK0BebX8+^<|Dqoit z#>-*}g{Gm}cCrz2V!bbDCNCX_Ec4c-fLO^P+>s^P>N4}InB$aTN|AV;Gm7N~@;S|T zT7>94z7!e#naB*;dUd$suqCad)@D=n#@1rU$P&aEigQZW($|$4nbTYpZoEBEkZS)tgCuZaNh|+#ci@)-qzA9Sh-3)<}SlFWY$zsXTr5m6O%&{nX>krVW-$1 z&yT)&rw$1h+(%9-cVm@#Ia)yxJ`9AX^AxY%6{dD{QNCT)`An;>MWLi9x$Mn^j6fBe zUo~UH;W%=OgfaKaY-k^t0@#=wGd?^wc)asO=jvwD*=8~y-G*9S8mT!>(u@?U3`&l} zGS4*J2nWi8Yl6p=q1xAka4`7{U)iWynXc&GUhUq(vOuc|-RlnqvE$*=?D9|1WliVI z>pE_ZinTF5GH`>8hbmv@Az7P397K)F<$%>@1GROB&jx`B!8s8K5SBUC}qwYXKm+|=XHJcUlW%j|y*UX%l=@t3AU+ZH=XWnt~z=r+I+LiFghX=4u0 zKFvMW0!pbJA=uy^C!%~VY7LzCiSBGJO!@-EFI!oW4>7ZM&G=7vC>MOOZ9JJ)d403; zWH~^1?S`%;P?jntgTaJ$T1O`jP}#lB?DtJ}>qz)K<-B zD&cb7ytFA-265#3b6l^w|~5;oRMz0)96D8-o4| zydkyYxDfv)6HJ=84JaLjvq7M4H-I_F2l-gK#+Ab$wGW~qSy+t;OX5eI^HN44p2+4~)|evf}6e7^YkyAbk+ei1_ww41mO{^?dd0 zhG}t|Sfg$4MfmDXA(H7=|GLjv4zKND1En4#+V_JzsJz|;9wrX2qCRkY{B1GCx9#_8 zgg|@5f(gLU@$b-hASc7;Q8_+@P}7UV;+;Dtk>P#f-ywwvyu}tffo}$6s-D6$hObkM z)(4SH-!BE+9^uGP)6*ur(?(Adubj+Y+oyz(;+>0ks|E}oPiYCuzg%(x=YyFcNV<$3 zD!jtjy{xTs*)jO!M69OOY!+n8Y7iy<#XOwbBl*^$kyqa#5=5VcqVoo_aV9U`s?Les zDi|=^E@E%S^DY!&x&6(2`@dv6SHEYMKd?TBPhm+bNJHx|AY@Z9fzgi=;h%yp$Wt0*HsU#}{|AbVfO3q=S%hFF_` z5`6%?34~B~ac%qt_p)RMu(J@orR{0d>H`4S^a%nmG7S)HaOB@U{wrb=L^mi-?50Z^ zhw%dZ<#))(=con^{a^B@zW~Ap-9dJK{6}^^OS%TW%+dy*D~;dU2VxWd8g_Ni>A{x( zbTk@qwhCiL@MlJBTfw)EWcIbLBqWOja=fpf5mrR{QOYc%TzDCMW`$jDLjL3qGO5e z`Ih$h+HN8fM%aJIWRU%1+}!ObCMrWsanE&G-i;`o35_B56%mMJ{Y_k_Fx8RON z+QT@YmhgQs_Beq>zk3cI&7LH)}3&QHV(F1y2RbVP~J4II9XcGJ|U7J`>`>1OmI+r z^&Bnrz+6B*FJ!B5#|)f`^+A5P774+xb=Vy&elq*W%kkAxoE`VIBl03{7s|Q6U}?V%VP#|* zLy@oeiMQ6V0U6FG?5sGO5ryDmjqg4xUbWA9kmqs0sRnueS|~r3XWD))#wwD7f5GU< z<+`PVF?bM|l5~)(RHnjh+)<`CX4VKE#|$PLWa#ECxP2^${SqX%)TFR-7+0}S3voGQ zicOb6;fyn!r9WlAmQUCW88luwLXYXcdf7f^*iF`6Y~~xQdyOvft9Xqo)TgdnilpirPvKcoJ!&95ywoXgeQfG-~p`J?{{Y z{N2+u5I<#pF{*$8xq`ITJZD*)NfM68Cbj6)xEiP-^Q!e>s@6x%ga=ahLn(Q_vAEB< zH#>hRrw`V_O^irN!RdAEJ8rkJD@v=ChQHMAsbwU;P|vD=QlhbmMtc#|?Q9un;6yPc zdYxIqfdPSrmm`~z$}yXff`sLKCqj*z3nYP-1IO%BQ7(yXt=x1~aMBJ*H%{-HkGxBjKK};H zC?^}y+7J+=s@Z-=kQ}*>Z@yU3U4N&9tsH$u|CDkuaiKncHV2zFhJgmdj$wme<>tM~ zja7By-aa~gG^oPrMW2<6ZP!vyt8!6iuBqKnds(z4> z+$E-Ib`o3<)~<%g72lAPSI{BGS{iguH~SrcuLum&)(n%%8DP+ry&FNLGSIqU-;WL& zH&Ixw4whFP3(?^r?sFG|EM?{NXxicopjAMVU2!|n^aFCU2iaYVnwiUlNGclM^K(4iN%Vv zj}0%?NazP@ox5o*c_*}~Mn(1jM1WKxA+b=bwVVPQrix=`k`i-Pi5`8Q;8n^yV&tTF zkemG6286ggN6|KnWU~!QHBlG~DoVw5Xmt=t?Qh)Fn{zViN>MG;GM<)Z-MTpTcEO!+ zuglDnId-X9nNgoWianQ{STN*}&|gl`dv5%}0`zGyTm3AnHL05zKXgv&Y$xh**->5; 
zxyT1=f8HDegHr?LeMu`^s`iHnzAs6@#^nNs9cc-zzew@36s%l@xLwWl+IY`X`EN>n z^&h6Jzr3Y>Ce}QLxGE*Q(%jIpjUMxNNVX%`Rm(Qrc0D0uI%XCW#a|6)z9Hwm zKce@!Nc0zrA-#+McimSWee{e6;(DfcI4X+uy{(mdPuI1M2>1%6 z4lte}zZd2<;mi4r#!Z73!#jd-gPEC(DJs-DaxsT@>T#mHxFT&Q0w}-na0Cg~5QV3u zprL$r>xo##lzUYsE{Cr8D92s9bf+-l;j0r;Z>0@wIXoMLC*b?bkJ>wp?kfls=F{3x z&5PncJk*z(Q({l|n6l!P-l2!V-1++;c8EjsnXn_+xte$yXEs>yi2M8G`H1(j5fZ#r z(;{ac+*nlS3K$kRI)m+j0L9AJth>)F>%eh1Yae z!)S~St{Y5bi?N^+gJNZz`!L8V z+qgr(MXA2+8PwJP@c8BRj=|4HadM&8n=TfSU6SgWj?X!^o>Xodjts6NpCnR#-O!a`llA9*9$+(woaoihC1y-KEtAdgv`k7rGTgc z^Yhg8?vJ@4Jm#p!mKRwc_m3A*kzPebc#5|%qvZI;ar}^awc{aB|SU)2d zjfJ~TaP9W9rJ5*;Dn(afWo~00^vFzDw2|q;;jnWIL0wbiM(&=q*6=(YDCYTN#Q_^U zDec?s3CKO?TbN0oQB<#JMKAP4Mmw>}MR2uL(8z>dQ`0KInuECoH-V|Hh^H<+0VlWz=@*c*TEzl0T9GqNUtda;6#JtIP ze?h5)JaD;zgj1|?jfr7K%WQsaSVC6`t=`SuSt=tO=F7pHJl9{(67>8;Zwi*&NoAke zQD^aPsWUjws@@cTt<#v;SigRqnK(ng<=LZD!2)+|a~}%?lv{-a(8V3n#N>Mp>jx9} z=i*GB6hEv>qaNwXG%dfY)gsKo6)I3g==*5V#IuC%cRlSan}c&ns``4lH=-a7)w4Pv zb^c*-3}!#YiDMikw(R3%ftZ{TkjwUt+Nfg}EMyK2u$dQLsF*mRbg)l)hWQ?i_)7^A zHxR<>rLD1xf>+O2^BteMi@l9CW_&_>YOw_kv#NT=G=Duydze+to>m*HTq!Zrd;jgt zf?g-oeg)o-2kP-bu4e16(Pze@gIRF+u~Km0rX9{M)%Q`Mc=?3(WS`5YHD$swdBmMV zxwa4CHyW^}=vrlRf*b}}i+W*Ote6>KnH6ZeWf&%6n&p4EQwu)eKqk^u&em< zx$1eH`xhZWTLZck>vFefi;yRRswT=$4-!ZzU8}7|Yr6+5mbcxc*CT>Ej7j9sgJrbq z0bS3eSaF#!i$@3LMinVkzMgtGo%3NcViD1M?i)p$qssoFir|COsKYR((ReysRxSU6 zvr(pwo`a=l-dpdR*Ce1J8>1hILO9wH<>$HS%GCA>Rg890=drxok&4gvpE_i|@{teF zm%mhI8pp7(kBnSQ0-D|zljV_(>)N_>ia`otL(MirCsP-jIS(fmacIzC^sBW{zH9!i zL3(07OLzab!tmLP7jGzITe4nU^7{{`_68|?} ziuz9VK!w~4?@_%>{4@p|M|(Lx#yMeHZxEYtcEkJ!d*A^ur9s^ylHeY``mLtO@nHx} zM^Md50mbYPE3Q=*hZq4-D~35_es7~7Jv2&__5d>?x88lG6LSE`gF7|c{BRtF%Nwit z4)Kex63)z4sd0gWxf;RF7Oct_7$$l3$&4k#qZf_X9KyZM^@B zgttam;%wSnl%ft?^tss0uu^TvO!`LY{3 zqh4lcf?7TpPb32|kAskYfMJjv#@jmUOF!_&1 z3g596G1bY)jUb@*YIPTNHK*-@%${Bqu1)?zr#X1_laYbEsb$DbKfF@6A<0K`C2EMR z5Qtm-zjDq$Ddz;Sx?d$mOlF3bY4br&4@Afk{EBX`E^SiuA<$ePS|y|NOMDK%*l2z# z-UMmy-Skku8PYf(YTOP)3z1@=E=L?>$H6b}D((E&@9~0du~tFu7Yo;Q{VQ!QiZ>_U zWi_9<1&TBS^p7d99%JVEl;5R2-UydgKi!EGzC@iLm8@1`Ew;5zwE1kZE?Vrh=x>pl zecR+)g`N*NyJwzbz+q|A<9-YW6Wc#-q>Q8JkQ>&hGX6>?TV^4~*Ae=IO|y=)3Bs(!o&JZr{G&fb17 zCHfVi#b2=+J%{HLY|)igjpfTsxYaJtxlcy5Kdb^73NyoBHj(1$q~Z01 zksjqxoOT5>;u}oj@ySQf(paK#vQ!orhkGIaGK_?)Vlt%n)zfm9fgp%L9V#%lOEOIJeJ45wechWY5)S`)L3zsfL7$MNhGV`}V53Ie{iU6eGt>X83t z43%Mgh*W4mP7WtJyuh!X7x8%+eX}6wF}G~c-2sho>w8Ss7b%+^o9RhLq(l@Eqbb#J z`Wa?dA4%`G2C1bY-37xH#V^{k+VAv!5g6B_%uioEyV=5fo$w4Z^{r8nu7>)ZMPen|V>Oy|%gK zZKh%Z;!zXJ@c@-8_u4C|0y%RcLa$NzGYircUkOfooIoqpE_KxTY#Z#Vvu*w}iETut5*K}*rxmzvG>j^X@WhSQmxd9g};cf_+|VKa^xOQn5h#N`tOz()5jEOn|-_&W>ab>;ylT|dgT^{_55#ojmLK~8>2 zU_#NGa}Nwbq|q71cgEZ=h-}m1gv~A6;71WWlYwvFkcSkW{anrASxL^G&;jI=%y4or zNx+U!qL0VKkLb`=ZPUpXVstY)zG*jdwmND?ah+4{O`g4u)Idt?#1KzP#&6q|fdo(5 z-({CiOnc0Ha0Y$NIKSdWeEcbRu}gaL2l047PX|&G0r6W001+Nk_ zyKj;uDe&$vm$RX#J}6t;C99?3<=Ql~te6a3)FyCIlylR6P)$Pmmvb0Lu(NI_I!j@- z>!~vd5l?52IDa~AbqH)P43Ja$)*5y3d3x~{4_!v$8GEy+__)pu26-6n z2!67ot#yKk)BqW&!1~qyzJe1RJl+Zc`YG{mqePX92hA@?rUT!(yNEIa0gm;^Zm`cj z!N1J#_jIHHKX@kDRa0plYETfY`x0&$8u(J*PI>EpErgP^Bi^Q03dGwu__8QMW1M~dK(hZ6e&BCvY>iuQ1DTQw zT|lPf+u=(0@%Q4;9g7lL+GMpsY>k-_0{`3M=C1edZxK~SuBe{wZ;aU%;0fEjKO2-2 ze3Q+Fc5BW5AL!}45~uM3X|%Gc?+_zxz&*w4a0YNs-B`Ubz8gWHvL6NDzZAbp1_GSz z?{a~EAhB27JJvwNk3u5l?%i*5=&HJ233GrAEX}>$!tQJY(U8$R4EBJoE5x!))as=< zz28VScDP(FEivE{D&VkdPbccl`a)Ym(pQyDUJOyH^<{xW;l+=gvFIzLBqur5ctC6M*S*x80^swBt5 z=8s7Yt-~3we-jtmcl{DM`*Y^yZTZHl$VivL5;K@oBZrCA$ z*O%Egz*mIp90QZWr=`oCD5E~#MIyPgL{U-5`%q!Bdz+522rc(pt=J)&!y7!hgiqG+RNW$|{kcZPbkBZXx4`m>j_<>U1qEMg9hzkh0kP92h=Iuux`~_i+5w^^i;y#Pthui?W}!RDDm!jG3)1g 
z&!|n{@HxTis|QHN##X^8q46A8!~xdgo5N5A2W{;VGg(_bu&QJQy>P6;O;foM)C-aG zH&xQ(ljNGWMKoR#sJlk9amdlWG`^HKj}$^a@id7we~y1`6pTftBJ*4XRyP+Ga-)X^ z4LB;ku&@Qm&nuc6-i4b+erD@O5B4lrxIX$=)OmL8Mi@TS=moJ+1aA#LJTOdLo}_Oz z3pwx|(vuz;zX8v?Gb^{8p0r6;ZrxtLHd!8oLA}TwkotVx)H=QJWw`}g{(f@_DvhS% zmTXpdZc@}|?biHj_Z~cMh@AW=z{=kdRe#K5E1z!Pjg5&kHrJVS0d1=cd-|2lZJ-^xBx%Wf~b12%f?Gj zQ%jeJJ^}Hzx*4u{%e3`@_k~?}D6j#3UG2ltyO)#q_A!!Gj4ZhK8sXUCe#Nu%qTzS8GWSz)3(dcx3WMt znHA!-{KgR@0sz;QU<1a5EFKGLcUkneC_o%*7#IRGeg&J8wTB1|jOX(@G|?RmDv)Uq zbwy$@?WZtf8}4<{s*2?f#7AcHXe5jYx7At(@42o0yzP@>_VYu^7o&(2WG=hI4y!&r9u5Ns~b&- z=kraCuA0r`NAm6Fo;hqvSY2i>_}|QC1&8ktb})0%BE&mh!$X^lAcnRziCE>-M8{po z7YPfVpIa&%T3&RK%JUt89@FOboI3AP&;C46uYJlecYK+N)iE?SEd3My6 z2q2ZdPS6FonFFPt5Jp1=^9KdSfnUm4ZMH%E?1pszgZJ;+g&Q*7Yg$=qb%W>g!>TqK zgzzIF(9``pmNHG&)v-fammaKqymtYLJ&&4(PxJ%FD6!W{A%FQ&G{s;ah^SvL0n7`T z1h1b46lYSK-{_PXmNLYK4zMqf@!djErW=Jth;pGg4Y34D8-2@*_*WB!@Z zYHh9SE|@+`{71`|vYAX<*eEcua*)O3;zaL?#(O`m4o-!WnsuJ$>CLzWpKOp*Ow@I_w zD86^>gU3BKs!(aXs|B*l01F8pm19ezdrYKlc!SPXCGNa zJ%>l(uZImsu0RZjt#oyPt_O&dB7KUXK2mhJ$&N}b!w|(xnNu|%ueN{Yifw2NI{76P zNLB8KjJXH2Qo6RH4hbt$*_+)(JBuZ%h%nyLxn>x}(`#|?nejcERP9Rx;p0=bCNttZ zWP(7`rd?R~tl}ZnPu^2XQSdVh10*jJrb{dWsHhxmkpgW}B!ttLmdyPHNz-QwYy->5 z&UGk#Ecn2c^_j+jf{7v|i1*v$1e(Y8@M_Q6l=_n$MD9T2bvlbl5G7>CZ_M|U7!6KI zSS|SpLn&}p>R86`(dMfi&0W@4&h|7@wsa)+4ua^3YVF{J(Z9^FTeBjB+IF-&h6Sbcy&ZMZ!xx%G1?}*$tEvi?90nV zuj?ml8NXe)ZFx*;twZ%PYX@jn@MtHtNy|tlsAHZNG^@6J887mzXT$eWU%a!%nbGeM zPbhV%Ra&4)jiX>wswrhh^|}0I@t_Nn#}o+`D~4n4dT(VQ7KfchF)}I7xiJjy!`RBb zEHgs;WLu)EW3r9ZY2+(53ycQy^{bg#e(?ChN zRkImrQb0KwCT5TZN;}!nwMs(Tcj%Gby^}W?SenclI1cxHqHM9cWiMLk%7&rwi(|)> zm1y5h%ynm$d$st~aPejKyjaUVo7*ZB@X_EykEMjZ8ex`vGi4TIEmf0Vlq49i_~@yI z7e1swO5#!p0>eASF6S#|;utzOI{4lc2v(KYiQdeft!HZ#ZU;@ND;pl7#>K0&+bG($ zF4!>xr)r#YTAslmI^dGby&&2BqGUneCW{?;5~yhDaQauUaXwZfsH44y2E^m+_Xt?7}MbdB*IiscP?ekviaqsPugP-fTJdSsU9tEuDSxWL1Z1Sy{k9VFI3!XP4 z#nt7ko6@tXm7jZ3yQs90rTK8`?M)`X^N{+_`<=QF^!Bt>aTlYNim+_Kz0ZVvy=s%v z&b?Hi@?^~j5fr^=5p$HJ9xyQTkcE7O9T?H4nJX^v&Gm=!-tPV;JW$=lg*<(LgN#$8 zmIRQhuq#eEQavvx1o~U3jy=Vdf11 zI?{w>=(8VmQta)DEzy#(n7l4hRTEl5_Vm&@2i|zcR_mxWpwmm5trX04kpng?Pse*?UY<3Mw++k(=jaY0NTIto}+6R0l_f z39Hc}44&%BC6CyeAT;_l3t4>LY?%sVQweo|IoF3ptpmpqj*0hh<{Q>lD4{*@?R>1M zi{*x8UBLW%W72d6Dy6Spkh3GQVx)z-YU>D`bjFcIsBGpqPILr%aj+J z0XHl@GA3q9c>a|NxybtZdIVNW>)OZ4gCFFWd2U}*b^Aa8A*moO!kbaxEIVp7Tty0U zdouMe=XdBhgtB4Y?$fw$>d7K^;>t!=lpmy}Yc-Hxuj84J%ad-2KHlUr=LAy?A4(+< z@2$i9mUVg8ZE=?ppQt{Gh+~tB*gtf%Rj8wuY28o5aITeOcok^^EFk^}@G5b9royFH zjMo#Bbn9JA@Qs|yizo5H%YuvbiH7W>k+o8^n5n?0_d1*vjn|t@OkOM|Dhpm;Z)UU= zF?h($1f%op*HDmwSRIR_#tgn4jfK}m+PXfDUaVJ%8y6jmUmV%j3(p<8RDN z13pxA9g0$ljc~1#Z%z65j4W}s9e$eo7+rOqp!RXUYZRdloWe~#uFXUHh42l+hv2li zR(;n$)eV`JO6bx|h+RBvE+<{<6i!dlgtM!RW5y)6q&Xz?Y)82~naEmVZRE;4*Bp~A zc=7NhF;y?7;XA?f_S%=R?{_9LofeqLOJ;A5Ji0l5Z&vz*I{5_(4!>y-$v%%Yt3-I0 zUvfGC(@0|3=(xw&J&zWpR6o&nH1v$YVh?wB_A0bzYnmY3qPMTE7-Yb% z$M26Eo!i&RXPTn@`fF^*(^N|8*FEwhFF0aYmw7XG2IiQ(M0K3RS3}H(--6UY%IE2Y zV{f=L4KJw`i>MFd(V+l`eBK#sGTs^ic>yyxZq0=RWcDDqLzo!*JEinOa9n8KRNoiq zlOh064}pfG5Yebbivrt0T&r;FA8vNq269tc@t)RDh_gs1OmevOK%j{-SILa*Bg7@c z`;#Ln=N0ymf7sw4X?6}r+EY~O@pZI%meUhpUAdY1w~tNznm19J!eCF~R(fnCn3Q0N-5);Ubi|lykfnX{P0yh5;%TQo zOnWOVW4ZYRgmhU#075l!_nbH2RIq|I;@w<@v`yaCP1V>r4*D|$ue-jX(|sah(WleI-Ru04G&8qOH?(1y zzm@*}otA}cd$ZaIS4rH^L&WJ=yODGx?c?6SVy)SadCBVukzqYK3gm|Lq1|Mp-Zqw1 zaw>QBQ$>pwd~&{oOBto5pNleG==L9tmXzc~eael(^|s!h;?-_tD3(8at?CMwrRkZ% z2DVoV19Y=;qjoh2$|vQl9cK62VVuGf3oXeG#*%(tHW)*-FOR3>!xUuC`ob=H+`S@qI3x_Ubgf%Wf|2K?|)u? 
zto&g|ru)H|7i##q1zjOMXq)p6X@zOQ0Cc%lQtA~6H({4F$InfI1AP;bJ|ZCnr@D7Q zE>rvuZnKo41cjJwoIB;7lg-_R#Lb;L#$5sbA}Bc!tbKeWq9fCHWrs=C8LvX23l3Gg zO8kD(@z@n3rU!HhzqMb5{F8175uBI z@vZ;cJw4@bvwfgcH1+JPdd=%rQo}X&HPBVN8+~bXfb6P1^(+-o%RR~2;3X-S2Qir; zal37-(mGtajo zv6k9O`S2O5H7|JCPr0#w{P2JJvX@v~oIv92dG`V;*ED-wtPaeDp)K@Jvr{Lq+!Lgd zOl)^jNqANx_8SIEf5;^zknOH^eH&dvWG#^9Dvtd~)^5l}WXkZzIg6UT&dhDnd^kf) zzCb(Cd@B$0EZ!NM?NHh7^vm?u+~TWQ4^G9JI^0i}I5T2B^=>p!kY{ zI)jClrJA}N=O}n%lQOzziBBz@pN|@jc`oS!NqK_6L$gx@gGgg~0n#tc;t&YIq(+J# z*EF)>!J-_bAfa7$W6Dm@N@F5m1%1G#o%FnsIyfkC)x%#hS~73dL7F1 z{62uiG>e@_%p_uG;=;!!ndXgKXL0r`o%8wnqb{_`jfGcj(p;=q;18uY?1;KG7&!T{jd{BO64c9M5IaP98Hxn z-;{>8=v>4(D_)}#i+M+MslCZafrm49%IWB&85L8gqr9qNg!JrUMcC@J1c~mk-Rqr! zS3~kGlOw^Aur2-^s0_L*R@3L$)j)2~#*z87vWfMn*)9H4-yx7EbAwPB%Ba8(9|9n! zFGHXpwzx~T2k?cg(r)eVLkxi;)q{XL2Iy?DfPkngkaGudc>$O_a$0d@`x^C)Z((NK z?3M^?JkXS@VW#{#f2A|i{jdyR<_yCp4yNP6F1JNiZ}~Etemps-_v!2e(@3Tsm z_$>?w#NYEL-UU3gc@SgZJ9gFh16Q(l71_1H*pq3bhy3iV|IvMD&Oh=3>`a#3yBwY1 z=Rgxav8%wiyPi3)g2_wVV&O8I-Tz>v-Lo9*a6kRFF8=Ade!BxXZFLlx5$Gi+RZ2{ec4z=*I5-o8hN=DpR2r&Onk!JEVZ}T&tvRP(3?L zI`Ntld}|4Wo04y#LDuX5)CVrwt;W%9{@pPCcp@EXOD}05|oUb=1b8fNz*?1z!f%5oB^Uv$G0O<!&jj5tsiD=^=Xp5mzR%o;hpQBhMtX~4~{v?JFUNG-GHExhqvFZ8g&sF zXU&d!ht|F;Vj|9p(FWrV2+!>q5aH+yJ)E|7J5tDL4zJ0#^uXRwP0%vyf)kqEz-xW= ztSZs)>&$qOPnWjyAhKQ-2fTR2hzr52>+Mi0N2`|O7|0N5ACw#1KgNQQvK|EkD0o2((uZjS5 zHit`-hHHOVK|o#cpQ@N4b$p(~yJkU3^++KxJbUl71~bZL;ebI`0T7oH53FQ*BX04( z&`r}{+Zx(_erl%#dj)yOr1xlR<|0sLGn?SN$jwPM_zyTNU|_sXumP1ITA~N_Q_ATBo_f*+rn_+JRv{X|Uhu(2=M`8UBO#e z( zntCS>F!wIE7#0>p8kmn5$_zetn<+X#66mm=V}M1~ZgCM0cbWQ&-q zZ9J~ztLNrNv|RY~>DfWUS1(?B*L>r^XVQ~Y%O~3V%XgNZd2bKv>9NPs zFNiwA*VK#Ssk!L0`^&41qlFEGOV~pd+pjrLd{S|| z;iFQR%o*hi>6=f3^r>b)r!Blvp$Nk%R}JzNVZ+fi%Iwh+V2R2`Jo@(S?)8nsz1l(^ znr{|Q1u;fw4j2r)_^KlNL6kLnZHF1Pbg2T0T^Vh~Q}-jLUz>ad!a3l83Y7wYkdFiP zL|<<=$)i6QR6*#{7iFTqdBEO^Sg4B-LK<$zz{kUkUzC+}M?WYzDFQ(zlp_?Oi zTThYN+fnb0r5H5d9zUX^le)pOfzaC2G>!$}<{uRKYG|+V5_SEpE8{~s8-xfr6QT&< z9Fi6EO*CrZKmFy7xeq>myNv&f-f2WbE3w~)82W-E-><*n!SNLq3-nTP&jW0$;qmeWSVAFH^0kXz%eo}AD)99nqtfFRtlQ0j7%;{I#)82vK%n(~OE@>mc{i%> zlVEAsMt+vZTSPOJ9(Bsd=JopQ#8pg}y1%e+b$)qLTxjZP4_P|{alC@t$^`f62{ zP%;T6Zca$^EuCq7)=};wnQB&`)SjgM33G#%>RYdcjkXB~ z<$w*KZm|vqA{MGm-D?p~d{o{nOSG5uafKY^B``#izj(-;F+wmYo64ACHi15S9T%B}XC) zwtQtsDtvGBrZp#8sH4Q-YZD+c78 z`{diXQHaxT3UR5;9Z#4&D|~F-uZ~U{#c`i~fBVoQp@5s(2Oljp@VGaROiruLoHV%` zY1VFHYUR zF}coT$DI{m_HlBY+~VQY3t#7KRkUK>NR&3sZg|QD74W}A-2XVM0RRZ(p%h~YcyCwe ze$Z}h6crw^#I!l=3aUKIU4?*{onj(%f56&SUmpVG0OI=0nS!anRD79T%Gi=q@bbNG zxjpI^hMmL@x`veJxUP)AVqgd8EC<}dLV>%&NstoJ-*I97u&2J#W+tkMPs0Hw454nCh9Nf>wn zzK4%3K!gn+Ar-3GN^O*G)p{H;3(p74eJfiMqbm@@`yqo~z4N?K-FE$C{fFD9KC+rz z3R3L3&2+|M{>*x5B40EJ_+t#wB)rB<8eOmF{Wt^X{jC-g`v!uC3q)$RY!{q7x4^HH zf{h{yfwcz2;ogFdhtczyG@xhT@8JSKvtGZdHgmQ)lVP31_%q4Qi^nzU9-=-ev^2Kx zXpj|ipoKCG(%O2uT119IXnYq@Cd{h(kXgv&97E&Dm)Eu^h%2IFp;a)@dyL8S^6o~E ztrI~$vDR_k*d1dGe*E0 zEM%{ZOdBe*rga;6sj<+GloS)9hJ$;%t8IE(#-mEp5bGFx66#T6ogXe9PjQ=O+T5wU z_LYH}NbsAHs+u6BW}T`l+vD$-UJ9#HQ6+cjAddA5V~b4i!}&2i-A{S4=46|Z zhBte1r}T-5>E}ryj)s>qb-63z9XAszE@doh@IA?<3v^vTb!2fA@>)!+KROm{)_LHA zaHWF40q7o*k=wysL}U0y%fx3`)-DNmN$*EjcE{2v6kqXLj>L%x&CkKZ2RuyI&9+v` z!^w|avBf$%Kr_wWhmWwo!ftHm5lC@iJ^#!QsQ$~i2jmKF=Rd=;;VbJ5lc`lMVbJH6|Nh)D zO=Z{!yYo#LdnU(=tQPkY1zIooEnpU>(=rv!B7WX!@-;;znEYNjx`r) zNb~p$!7%S1(W%i&@*fgabqT?Hge< z;LL$qwN&>CObyh5wg6D}>t_meb|CP+-ivrd=`_tpwfIWtXP|@+ z%vWfnNkkQ#hA(PXuaE~sl_x>ITk0Wogxd#H*rrLa-N1F_K7~&5*JneIZ8X4tLlGp> z^=oVicwkTR^>Zi|aAME)t3mp<^xeYd5_^R5m+zAn(jJO3TNR#zN^W2EBI^mp>;6x6 zGB*!vP~H6Z5Aqvhkc(Jn)zM)t=3W2d(vWK1V^p`)xMPJg7x@Gpx@EfzvOX8(AgnfPZs3sTci_H0KZ 
zFpUXct}+IURKC6!IRHuQ=tQp&P+)UT%YSNbzIurW6_opC>7Vab@DXS1E;x;fNp&OH zAG&|{gUf4&x-(1*$=P}Gr2{C-=C(OX?=)O-(qJ~#shUBzQ)AP^FDward;F7 z**g4Cf4qa@I%7lfKB)H8nD;W|nzexDt?vD#pr&Kk*iqIfTEL&N>p}EQ*N`?wHK0JE z;jc$`Q|>IZD8-o;w|M2w2`>NOij9bZOHYmS0z9LvYBDwMqP8d;v=noW>r?v}c_Z|l zFY{aW*hzY&J9LD}hfv43(RNvcN6i-)cZV@2XLm` z^A{dCH3*E^oL!#y})Z1`;lQu2lRrvYAt9s#FxTH()kz*Vd1hlMEDS>j55 z)3N>;qnEhX+NaC+%d0HQ{QDTLS!~~7zLj~0A4&|7C;S2@O7!apC%ILEO*oPwms!DC zOfKjHJK#FlB?^V&?>57Kyo-ib;x&v~Tciy0w$(%(1q~f9)tR_;QmHK!C=AbkvW|bc z0mvT56cl0|Y27X^3hYb0;d8O`_1rM~d|LKvuorwSQ~|FCDnrC|;4|yg)TAUq89hCu zYGrlB!X)hEhTt-<>k?u`n_!iKH-%-33wJiS>&ISjMio2VPqff`zSqvmj-?8;bJ)v- zw0Q+z@Z67g>$T;TkHo+3GCbN?>9_gdn+TB?{9fb`>wt6vnjG&4LnVItIRo2?v1HU5 zBKR-?@kTA)UAgtzVj?+%=XkNaNM2V}nHolQ;W@^hLdAEf5gzI|)QeIRD6cH$ar zN${qcjfULzS3q)SS24Yi;in2Mt+Nk5XuxPHcNZdXD;oV2PvxYfvdYU{0yB?z;WKYR;53}PL{Fv)EdD0mU)Vzock#pC)` z4W6i{>cNQ@jtNFcPL�d!GrId}|pusQ-U{F#jEXFbu9gE$Q}=MV!Su%Xt~|8ckrQ zQTewj2oIU^JR^Dtruj~bnz(&kvAs3=#nN&wf<^3)m(0B$_SE6^tMgJQGk5E8?0xd9uyy>>EH2QI{W~ zZ7~1YBm9zQAx5NjGca=OwQ)PDyHL>K|Eh-g-1N%8MYrE)Jowy&2Kd0%j_n0I{fQqe zZBn75hQ|!{0mGb64=LP`xPR|I0x#~K<-&q6H6L18M=|ArW9N6KT9p9e21kGSbO4o=ug}6=SCe^v`GTd! zgPJe{cA!!)?vI)M(bMKf3i=V=e`}EhvVXp`Sb+GXLqNK&HdTtWnY>kYk=&5|F3S&f zSPc97Z{q+69&iAoGTIAJ3}Gb5DQL;JwY!&=5oNi194Eowce@Mze&^ygj=2yZLY)b` zR(?+PIsp{m|DMBnWG!{E1e@lhj!7@XqW%3pIMM7&bpp0}5AyM~8`FHhVQJL38u)yu z4u5I^)Ed~pXkd~I5VxfYE_0!`ZOGI$?S83JzGY8G&02o6d^64n`@Ky}3Uz+*)3}5x zwxJ{@?-7|(u^t|YBMdGV7f0C`?V(B_6!XrM6qOqV&V?2OTXA zemk#G1LP#Y=5`_&!Fny|CRZ9I5kS*(%j7gUlnZdqoc0k!8R` z?uYAoTBEi?12%}6V&GMv1Qr?9uYhOZjZZ4aDe!1bcB|(Gw0`QRKgN>&gdw>6EHW;u z2il&`+B?m{aO`Ze^%ggRn>4eE>iTo{K16M(pug@_ie&8(y!YN-^M*yjl0kKGeqY#T zvBj)7=QnRii&O4T67i|-NDzVgH0dc42wEAp;^-N#u-TMnyP2i0FZ3ln%i0GRPQ4|; zA4kRYj)$KS3y!+Lc#4)bS{dz6<^ltAd?ttFZ4J%4>U3c;9{{%G#)a)h!8O+9qeF>g z(?Z1V5>=M)e#^QWSuUOxkLl_x#R`{K%zF+`Ou`hI7v558!4$l_-#;I(JW!wz|ZxHaMmUN}z1 zf!P4vmvJg#fpz(#WRzKn>)iyG=k1jauQEfb8ups_E1Onc_1eMh+V4A3Cn3~)4Zrtn zS7p27Ydu$P1A*X6=@&kWR41@f!NeVLTv|4jIrLOcxO`F+50wzFD^v;IEYo_gp~GgYl?O?gQcYm0 zkB;*?b^||60tTX`u~|7lvLxGEoW8U&7%}_|t4lqP&h}%LS*Je3x2>&RXD{AdWB?5e!m5*BSVjW4jc%g$2;0v9dbS#@;cjh6L~M$oX#UeQ<{$r_iU zQYtZDBZ*u8HuF!gOn}lK@F!H2mZi`~S|{@x`L1xtOlv`E>QwegZwXV}eNvxYQD4E; zRoW!~aaQRfcrObeQ}Xmtn@n~M^tjxwt6rU@e+NbXLoCP_Vz!Dtj>7GDGKHeN@{aR5 z@G*YCk@*Z0g3W%0tsh_Mz?+upC`R1W9{!*)%{hwsh73JPGZB5J5#t@5!2c=wrqsyq$%(Z77en5OWA0aeGn`VH|csSaXT2CQm!3zE3{E z{7G_7<^7?H_nkk9ycT~ZdgG#SC3*9h(V={-=MG&w;8joRWJjhZonrP{=WAe*ok0rQ z0{a*&dB>Jd{Gh?Dx8|^=WDv*5O@l*v_N+TV(hVf*NNX(TRpv0?gYAE{V$cpDzlX$; z;zUr|ptfHBT7eq;8}=4xH~k;kUa*IMRZVi4XrT%WoaFd~Vlalq?(DH$X6pxqrwmY- z`Bv!JL~T<}`7r>CRt%;dzZtS7z2-$9D9^&Wb2K+1W3z)H{Aw##d;f zHIzfGV^8@^)?`6gsKZ<~qH(y0N28`q%jZqAV~&`9;C@D*CY(Guw9Xl>jmq=!o6uKCY6CtU&zJ*>Xa1T$u3)jTCO_9$D0x(ASk9c&CuHUCrW3 zw@64}nCXM3S0AbeCA>PUP>>8~bl(ju*}OCZOHo^=cV{$Jey}^xDSYzdq_0A-U@nLJ zHCJp~7uQUoS1*saJYR9Eg_#ZeNZinM(y(#0*n*(cTzHwO{3r-1DhJ7`E6oJc*B=7i zE{h3&hOK)CCUOm=ZN5C+>-3I?>pkyoQyg9ROUrAWFWOS;JEXs)z$IOqeqp7yQ0VG8V*bN+(Id}%g)?X0gIULTtQ3J;0yzzK z9;%>J48H?Y|Ej2q$*F>p;35!mGn$1)cE2>o%&vb0Fjr^HJdGCd@Bi%xr^BBffvnVG z&lJ8Qn43B&pJ8n$`#~Em67a^pxc9B9D~=oe0*74ZB3up$P9PA7Yf*)FL;R|Fi=i8&rdJIMX*H1 z^40c{7nhaWVkLI$VTUJ;#4v^CyUlbTYs$BZ#W$6c;kNfVkB&aH(8-Z@H7rc2W({pR=!o`NHNV`FaK4r_z=PmUb|CNg7HUnugZK0x^wa zH;gVG_aElJeEN;?o2lj3*Pmk7^+NY@x;CYUzsGEw%+gFAZK@5nWU;_HwRU(Xh1O&y zNj(=&3@XhsV(Z*-txel7t&4O5=A2W^WvF@{XZFBzm&OYN^A4wbk!8ZNR&P2sM?5Us zee!K2UGe+I@C}JQi100JT7u?_7MBCeJ2--rxo*FFDZW9DS#gX(!50R0y*t;7yXQA& zfs}D0nYD;0+8}8axUQo{-aG%H+){kRv!f-XWTWT43$Ak*Z{#drJD8Ae-6Wx~n=;)$ zrHOqn3!5wxb}J?=%HiY_>kappXEg~49eH|jAla)pI8A=@1+a%LmoJ(%X1?wAIalLi 
zMnCIOcq#8xOTZeXvQtv?qC~15M`xXh0>Jm^K(2g-f#TAUfWikc^gnCLH1hrae$aOc z#hh))EJO7~D6aFl73;@Wr&wW%{O33>SM}3f`^e}jT^72XeQG|#uZf+vC?^Q-T+|O|=3pwB+y}|5kk?FSaX8gg20k}Oyck_X+d!MD-uO!R7l+C@Y6dA?-jIr3HH6Eui&i+(*f`5PFd3L%25ZnN@5@6F zGaL|wfml?bSb3NHGGSK=|A#P92F|zgkrS;$?pnceZ+Yb8r+NAarGPD>aUOsSf7>6a zH?Q;3eH1uI*E{TfU9b-bA{(h6_LlbyUDZynjiBmrt6$BK{DIiC>T*D9nTPJEIk1m* zOfQSOw*C^ZkJfp%sx|_^xu4wLcYMqruI7UE&g9vc(7FVAOxb2{J=VN+YjU-Kfia%i zZ2G47s@_~m#81`#p6a@Xb`_H`7tUbLxz&=h2M{d#CT?*|;@fxa>yZcUU-+s#EhuZkT36Ysk3J9P11%U!1m zD%J=3E)2>zrqkG5lY7)ZpwO~<7AXV+Am=KEJc$O**1xOSoj{SHzWWl$P#4{$XA%&r zCs5`QmB+L>tueKF!Rj%(k(~7uV+c53R`)TIFr4_Xn?uV)F3)mrfL(maO0oaEwcD14&yM zTm@XZZzB4KOHqqlfQ-~|03ghN{tgXJWAdUzva5w+n*=-gK|F#@!8mt-{sb z2xKTG7%fgp!3+rZcCRevlT{7NCvx35@f#;@L`{0gn&o`dm_a`(Q5NY#J%tZ>-uLnf zr`OI=0Vlvf0s{dM7Ah5S5CNb+aWx((NLJ4p#OxV&B!&1ihqP>&aUwkMFj>w(e1q{6 z4d{HC3rZ?U`1Ix5!kA>vNWKdOVP{U=#|4us(Q&D&Q5WzJ_Vcw@KPDE=obc@+UTm;_ zetLtW^8p!;_FKqNq%+YuhoJMm=BkL9wIxY{oazN=fy;Hp5F=4huq;&&sV_l#I2&Jq z8Z=qI$BFsK;k9pyDwfd63c~t+K>Kl?j>GCOF1ozh8eUttnZG(-b>o8s0h=q~@qt)8 zS~dYGRxNx~Ig4G|%uRYRy8uqS7JJ?%*s@;9I<@_ZfN#hGAN-s`Tq(k2i*c)*zAc=kuQPB-?+jvxnjJDr zfS}wz?zT=&zoU%E>vR}V!PNR-y@d#`vU;YIdiSyfiB`H|XM~-hwUmZ;sq`Qr%WSEQ zME5Ko+Y9+y6Op|s>vY|#307Q6Z#ioDLoc{S$8Q`e5Gwo(W4PyUX_rJ7hu|vjNxqb# z;;m(Iq785cKMovb0!*iGvcYMqd4?ljbHJcgc$hseD@D!~1h7Yu1HykfRe#C2Y(YUK z4TQ<8AT#TOFEfn57rHb^Yrf@g!bVW}Q?T{cU_viBOOsFdBgnZZ(uUT_zeTScn*R)= z?ODLRE=ZHSKw0P{2-sH>4eS4Kzzz?<3}_XuD1Bi@b`0b@g4jKEsS!b9&!lPjVwJ6_ zNs+cq{PJvR$BN-N5~Z*;4gsLgo|henrA>l9YyD-a9PvAgY8x zUs8p?v|-ORQ{gm2L@-p}_=HwrGde`yMfg+>`L0#cAEd-paeiyTY6qAY2G{S35R_{j ziUYXhq$ary4D^j8_%f3vYOxqZX(^dBpRxbFNDG=1suPH7Qwe52AbpF<)@eVql zXG)Ep@9bXOdtMl|sm4+KQ&(pNFTXLpa0aadI!xq=n8Fr38rkeAbp#gu3(r+3n0(Y3>YSf6w&>z*l#Ei_!y~0!j z3Y?ALDEs=ELVpaR1ZDv5i6M3`O@28C|N4f)oK!dkH1@~NxEsyzmP~{FscEk`uGP0^6RKj!o@!a=uh6{*S9ndU zK|aM726j(_i{z6)m+9M!`n{L7>QQs$A;nhgh5B1qQ~p}Kf}Y(hg)&9ifz9CyV6LpJ z1Mqg>IA$LHgiMO3`uHdLaRd1+*vghM(;Z(z>c77ElQsGE7q54TP~O@g0lYODj{a(O z7xbC5>!sbZ`z+xL6`x^ZM>w|g!m$kA>72ykC5z86ktV`3@TgQ1qha(j4A*M|9v_g- z6R6Pz=teVC;r!`AU~y|)X2X>z za#K>esoZ4U9HL%FF6H1PgL>Pg^E&c6CY>pVN*m?0UZ(p?B|W5{RoRb#Bjc~?Im>Eh zEXQF@yGhe~B|gJi(!_(`0lEd=&Rqq#Ii(pb?=HVy{N2=-mDdl} zAHFSLbL&mOh0n0pnER}?&SpKs0SZEv%Aic?r<0-IQOGH5`&MG-;f}4(_u{WAc%?eC ze;_aL!q>qqG{FiBOd&C)SK$wz^er38SUm3v$>qnvTUI8nheylfV{x{cPm1A819^s$ z4_}{dyZ^xKLDhK4MNk>oqJdFa)L28V)w8|R!J|HYmY|*dbU9ij(=LZ3p(7Yf>XR7t zU7FlnydrwnVPrmPeXXqUsXQh2qRKJlG36;`{*~5JIVPrf&2jchHx< zC8NLJnpVy7?-xj=oLbDdmE*H1tf_TZ4E(M7AvsFP@Y_r}+lRsv>OmoF1ux8nV0-oVG2Ldw6;xsQ>I&>+3w3LFKhi{Ccc?>F3kyi=<; zh=^J{emY{*U4V15GK_ED|50waOx`MWRU5=S4f-$K^6&p!X%fKRVewRnbIpej*I?Wa zl}FaW6!NEgb%^&Cp@ps`kPXDAp~+hiOi=wfEcr7m0gt7f!RR-n64|WimlkmqAzgR1 zLuqLjlU#3`)3%KKf*6%J2F)^ny(>)|0JSNr$RI5|s>phM^*q{h+59C$mr7C-Q2G%g zh(eVI!^xOoaD?%p(pbYMN%Ze_B{-9;5iyOJ`MTC2QZFlC$ zL!Tx^F5Z?8O!c@QJGL0Tqruv-dW&aWq9XTbbnKNgMR2>PaM#&D|Gl8NvCUnk#d>}0 z)uy8;bH6RI3Rezay18+D$@4BqH`~!=H?Rn%PUV*AOTKDCgMIguccri!Y-F3}p7~I8 zqlh-kJ&eUeM>2|arp2I|s0v}O&%2lAdn&53iU!=t<(j6OA=MrCIQ;WRN9uWAGsa7h zg0#N6q1~Z81Y(alv&G*wfcD#cb2fxP-czSF(6oI8rz~AQIG8*Rt zmnlLCOE%EFZdzgfxH5&QSKva?iqba#T2aeP1^>Z>vgn>Y1cbpB%DR_1+g88IG$jpi ziGBv%fH)g3dUgeS<5(c~cH-X&AeCnTmOl0jC>>qEL;*^pb|l#w@LQq?ka_3Dj#8kr zH2(WX3cpt*`)wY3sJamX!_3-V6gaYB@S(M($>7u_umD*z#Y#oIOMOrtc$}ccMxZ z*7&J3<;!KL@;u6Ws9|^TMNs<52yM0p`bVd@s;&=9>|i{z9_F?_Eyp(ADE?My)3F`* zR!*ItBw5ue&=!Ocnepx&=9{f3NBko`TKsvvGMK_02fp+ehB}yYEt1^=TbTrcE)y$P zfzHvDC{)DIOSV(@_skBwUk+B?mR`Lz^xRpjeah9dOeIJTBwk6aDW=|}L{BOj=hJsO?9D8zwh+c8(lXO?Poh*RJ2t&j0AtXBe1dH$J>93Vamgy!)mc z`V3B^WSYjnLeD1v#D|ViOy>$So2GU*U5u`B+lU_0w=&x+xwP4iCHlD1#Rl>9-B~wx 
zSdMiaz+KJl7yb;Rb(iyPT;FindqH0FOn~y_lS&a`qo&4Hshpvf(!LH*HqO>#p z@AJ_9`8K$-jwJOZwMG$lO_8IWYH!2Jta8c4rqyyw2Wz-{&wYHy4V*U=bj-Tw-C*+Q zcjP@dFMRmf8Kppx9jXLb=jnZpn`w5nBlS;B_}?24saM;2lXurF!^;P0_Epg+BZ1!V zR(Dq10rx|1TJ}3ev0Z3PlQNV-$4j$m+g_4mHYjS!z7i%#I4{Z)8hK*Z5yNnE4VxeI zu#_gFWk4?Lm_Q=i&ZNQLu_=h?qqliPR<*xt6%frkdSGaS6U_vZ3T!b?IEaAshlbx{ z$>c1!6%I|r>wSiGw^FDA0}9xFuY5k=<^rz^91U+Uno1ELd6AkwER1>E2qa-h^= zs=FXTUNuz-5HJr$tkqP}%74elO1o4Z%j1GZdw#o9L+(HMmi;(s@=5SF=n?ae92Yb| z-WKG6CD}uywb;Xp7*3#^;*BPeEE5t9@l6wKhhJrtgVcGp1(Rnr##C2b$ef;K$fS|K zkubf-y&xR$GdfSxHJ7yn_Nrny0vX%ZzRD6Fq&L-mj?*FTcIsgG z6xXg76S3)viu1>Vz^r<(w5ee4^PJQ3;$ti%a;2ep#%V|HF9i)0y;7zgePHn*&>lyj+xwLJ`3BK{X=`WgQ-G?Wm zclCaf!)(=jK7LYpNj(xfeYyu@$~S;-Ds4^2ZEq+iJbcAX+L)%)ul}fT4}MCLJ9wyg z?eT;zc%EqY__B7+#N5Fwwg5RTQn>B5*P*f{56&}n^WJeBib(WFo+Typ`%m@}T(|9( zAFh@a$stW)96a5%SzaWYTyWK}(8}p*ZDtC!I8rhP+y)fHFf(-lshKHu zkEKR;WQCJLhSsn%60agJmMpNYAdycCMP~9bOobxOdYe_R_#zcruhBjfxNfg3^63`g z4ed?gDcMwwvasxFS&PNzJL*sHSL|q z(10%Ao(U<4qg!!!Kg#kjAIa)Q3g+9yM`u3>(NP-I^zyU2w3<%|4BVviaKSSz^$*LM z_mAXmK>aKc~58$U=WdN)( zs3YI8l9`<@!(2ff+>Wus45qkh!|&`Tn#xIE-15FTt@uGz*NdC->i@|F`9MrCW=o($ zll?(t^H3Je0ll}~OJD%8lN@%G*EgdV7ok$Ni#M|X|!)C7wGoC5jelu2=pUo$^Nr9N)Zk)CL^-C0Y`RV z&$m}?@5z~m3o2+`ywDM{SXOIt_02Y#aR;59RJZcao`t4QV|kM~VqqSd=Nuoelsf@# zKWW76rHvv%xqMnbt>Yh$$yoy^F@)4aObPvVnI)h}2~xOcVp;hw0Ryoz|8VV+Zm;Qk&0t3&Z$ipu^x zr9Q$DEe$pb5W0|>0aVeNXAZRp65^J1*7X}gM{PHE3ymEt2~}TPNm>BTP@>G{3P59^ zs&guE**Y2o{GGpu=zpEK7In{xmVbsBnu4qUkn{EUyB#%La{xdY3f7-YHKtMjTJV;^ zbrsDsef41ePc9q(y2S|yMc;!4q+%c}sHBaset*fpds34d#$`5g`lqjPs+$f83>p7l zxFofvKq`uP1a@WYMhw%6HLBl$&fneYS!GPOURzhuL__e~8l+LNRFsOZu%@M?z}XH# z4^~oEz!U2x=?|Ar?-}zvtGjMtpPg^aZ0a1|Q{hGWyspFHeKu2$Jm=aUm-(1QM@Q8{ zD&ttAnb|$(n{Bj|j_lUJy>u}__#Jn2Q*kI{FzHUH$S`@)c>z6DHat}hGD*EmVQ#$7 z5$-5;3Nj~%$JSv$JB0bCt;L0={yhbeWr_BMmQk$@(&}*}W4FEKNs|})Gd8iSxG%(; zEU)ZOP}h9eFCj#(W?ciD@fin%$eD2BzRM!u>(@P_ z)LdQda_TX$S6~75M1k6$K-=q>^jgy^qvZuK*M0x;f~(8_lop|{U6IKa*hT4%1zdV= zK#xFH+Vs8_TsVs3og_`e_pGDuztpx!OB>jXB3+^+C<=3tS?U1tz8^~_ z3gUGD@P51@O^(eWgF|q*Qez5bIcgb@K+!J2tn>?#0kl61Z%3&#PJ8mGN)0}hs9apWJIbeh zXd|3s2NnSD)a1Qh_9|W2{Y0f)87fW@v6VAr{8y( zyrVinnAQvb1Q`fvoNL@dY4?taOBAs#6gZh2 zek$gmb-`7teL*R>M3=$fEvtx>p=&}zLIx0MiE@@ztdFdk&yKRTb2&7$ZDBON{)eO&K21 zfU}nRS16S$Iufl7X1WUUwrV)S@$(A~8*K+{Ta^AMz_vwJ0ZQz8x~v}vpH!9%zf1RY zmx?i0OSHO`Id1Itnr=!n@X;TXTh#J~am)f9nrAt}_47OYZb-g=K3>OO675l$)_le& z>m$)}o`b-D>wChm4`x6boXJ>cO7{xrv!k#wzB)^RvM7~V&JWy7;o$mto(=fGl1&cd zmuWkDT=tevkwXFN)*Lx2?(9P%K&0*ep6!_#1z54d*hizSCF`<^W;FwY2tQ1v6(FXY z2b7(&F)Jci3n@0ZO?*Zh;HX z4Jgh*Ka$ibRp77V2)MpXGMV*~qWII(SU@XX)4CsA1^EMg80>@Jn&E%*T|D$rtjF|9 zzd=rn+h8so1FTqb#;DPKg9nINIRgmm3yoC40KbizMg9Ffe#b}t;Sq-&#y4~ExZ*E# zDqZF6z&eHbMikdQ=-D)qvPwMo55gR!cx$xJ3{<4bqgkv9WP9uGO^9R)r`*4KIiq=g z6zw1KQS6~s&O_nq%a``+h6&PLy3AT)nI32X+Q^Dk@s8JHo#BRngq5ZNzbAJjNAshV z{SKGfI7hay{rRu7j@%R4@NBJ6Ip>Z;94|%hAsOV5{XD|&wK<*ci#!V{4i_^rWY9Pr zawfak?Dm#m$@)?!OQrC=TNAsAP9?Tgc^S2X_7#RSKiuU0ShsF=ys5sUqxf@A&I12- zn&sQMtxxyz-Bfu+_mnr(SbIkBzKFXP)DQQ`YbgxB{_?r;xs1Xrx3{uM7ZtfR69de7 zIqq2UnbB8`ByKXlRR2_3yQJDFR{u5PI)#LVLwM+*7 z+P5F4Lc4S<*hPgB?-hmug70q;dNGqS5`BoxOe@At0O@ZahyX$p?+ny52aUuhcs8NH z^3x~3`2-}NGwa&TZ^d`wBx+Z(3WkEYsU*LcyDH0lp!DC3R zpwu%|*n;-On7kXt=al9Il3yh|jLuaJfX+~Yc23Xq(gw_7PwPIM?CCD6Oz>x($0WN3 z61}D{;~^lo*`1~(uS8Z{LM~K--DE5BoTuj%yed#L2UN^WIRJR7IH-~Q877m|E4sWr z6L5zHgcdWlg#x?jB7iQT=t~$R(pJBJoVM9F|j2L|(T3tM(~TP4;38oe7+XHWwB0TY zlDh+vx;Ib+yDE?!M;_rrdzOsYMcf1O4eXF=?y6XmnY!CjYF9V3TI@4sOf0#1>e9PX z_ez51R=BUbppbjo_R|M?aJ!6AmWxowuB8d&CPxEqxoLXU=|V-Y#kmg^Jl1wd9un7o zjv}kng{OFHjao)u0e4?Ob~`tV1G z-FwXV0*0_bmFP{u3|BRdeBsRgp|)Q@OMVNcVvG)G!UH+DG|0hS5Y$bYsV;vbJF1`a 
zH%g_FG);fe>MNAMcHP{)$b=+u0g5lnDiZM@1xkA(8Gx>66Gq|W2YTaH4{=zcZ%#L; z+|6IEr_)B+dESmv(|ChXRR7ZofQcR|Uj(CpO#S1*zJZ}%-3MrcaDqf{N%+bNVugY7 z@FC=*kxV2HIB7lTIH3r(3JC`;Y=QEiP}RRc%0${H)b!vq{Z&oeUj}Ryzz9qn{#8?U z!+$@P_%cG0VUq=?MP{d2aHXlK_eTC%Wqx}#E3o@(#Ee7=6OZD~Yb12OKCwQM<@Bke zw==fh3J`ZeE3Q6)(^LT_w7OyXH>1cO4(Vn6KF{zMinx$oQvSi9@@;9I4aR3u=K^Br z;bnynS6?2DQV@bFE%D$GF$6o-G1mE(UrF8nqN&Vf$=`utqr0ekb%OpH*2ByROW2M) z^L#|ivhruxmQLF|#GtRMhmemySoYeO&oD>Cyi{5wd_kE!LOxFelB9nTuYVub1>TnA z*_gSZ+M=~&Im6K#QRtCENywy$emD^D>SE?4{$+fHzaAtEE=2i!0$?=Qy)2NRL2*JjJ?YUptCB8PO@A>H^ zpPOo}dwi4boH?x^b-bErG)iVqHvsQ&yZpAJ5s3JGL)ndp_RN<3lVuaoPLhT7ym;MW z)QS9Zxyf2cqAPb$H^-c(zIacXWo?(C`Ze=|kxhe=Fi#dJ5 z6ZFGpbB(u{B`)?Au%6$uDGuhaCKtng07-8tQVu(=ESR$^dE2?A=$F245(Fw)P1R8P5T)7?>^bq+6c4Z=ZX!tKu=8XU+JCA9K*z3KyH$4lTyxKZfF zimH0v1*hS4S%6Fx(CB9h)uKK4v(M8bN=Tv7nVug>Y$pKyyyvOOBD&u;VCBgOd40pP zhRi)}A{lGX7cE4ZaPS8%o1y|-v5jyhg>}xdMXl9RA2z0pJ&5xteZO8yzT}iWv)gPJ zDFy3LBdXS5eYr{UiTUKdi(Kb`rAp!oQ}a^FLN+q!tKEFs9= zK$PYejgY_E0b2fn^X^MG+hrag`caD@F@jSI5WN3QUNVQ?IwG^>)5sHt@k>tz=$I4a z;Fdu8j5?oLyqeqgz9iV>5Q)JrI@jzQ_yc%Ke};z?W#sS6*re~?fEo0K z9Ne_PmIs1@f+Pr=k6PIX>|u$wULwZ3L`kqSYoXr(JHEez?{^-&5&&BOPZ5AEi63bm zYbN~B{Q8a9_`kc;X{x_UPsgIY{WEX$P6CB9^Q5P8_*z(+@2ZxgA0q`I%h60eh1+Hj z<%soaF&!1QQaqwa;Omz(lwa^>Pb=5)9BC&0mq{V%sy3&CVrYF~E0(^qb?o@?k_Xcu6K)9e8VKESvgy=<3w7(^(6@+YF?ie3 zjz>m28i<>^NZ7>lUyLUEdFVoPP7~TdWk3J?}g-tW9 zDZ9NroaAfc*YJnvU(o$zDe;awV$YnC*|wT3uhu37?b)$KMkfaQs6!7N26RNJWXmn@ z60MrXT;r$W>{}1o8-A+UH&G|FH-LFUn?3u;of$rVI_^!dsrQQJeSVhd7;AI(8r=a;BH1dm|7BNe zV^>S*gcV9??AZ3DK>lR*Syj`5DP7~c2P4f`%5G*zaKq^MAG*>UVM5E|!F1gvk0Lo0 z0K2;+s}!&%1A)N6wk!b0jgqB+z(DJ~6|}CSSe}$`&%YEK8h@k+2n@tT!WMp7c7S1% zWl%(~tAH&30IY)K?o||9rE}Q%?M-Pnn^W(LA->ox*V8@ltnQvwiwCIbGKiEnjE__; z2QHemIJ_10IaloBLchhM&^TL}Reh6E?MBvlQEKe`ABZ*bZUls;_Z3Cf zhbunQ>cD%^o4gDQ{rL)fm2%#&4D5U;$9>cFsp#u~coX3WHb$o~M9(~$3EyCxf2&aOqeYFi!c3cWhFFJGGKYz=~4QT(8 z*&c}vEqa2} z9s2(lL_DeUKY-JMdI`X3LA|8A>ZGaqNxFjnfVE2{74UsMh7&t4&@A;H)S+wI`2e}O z>VHE_RUWij7C_{;$=ivtpIG4h0Nj|IUO`b>r~i^qIJ4+ z>Gz4?vScM5B>7Dz$gjU&p`$C{?_Wv*fB)4bFS=R~W6#t4)EfHN4HfK*(5%cDd|{kz z4lt$0nveiTmB83ea#?d-$ejz}@%;VwK&2G`w$D|{ybpbhzTmXr?*p7x5lEbpGr%r$ z{}()spSw}Rj)BQUs9RYCwUhecR9Ar161XN^@ogW=uWg|qwPW~dfYe5g{L!8KjS#cv zuTE?juwSHh^Pwnrt@8V!*!2G40!Q1A>bMT5`kNsvR~;$r87UkZU;rs35&-aOLQsc^ zE>W%bykeO7)gszlS*sezOoK};9v@x(uIv*h!@0o$@rO=)8vFB#hi0UZP5Ic z(8YMCYqmtc-f>aV-e?_P_^eFID4Zm)PO}*rfI9+)9bnlhO4%okA!eCCeIQK(thLA$ zuepA~_sH3O;4ogO&^ci+eFtcBHG%Uf2xzqKitO6_Xm*NE_Y8B}HTEmqUS1!0A1!E| z*l<%Q0LD$D{wy>vV5${?Q1ab@H^D8P$1YUs7M5O{O*1yw5(0C&*+x$@?uG5j!_Wmg z5i~95ax4AQeESnv*23d$$@Q*uiwSqqMEj05SxEDm7gf%-P1wy3h=J$JM=#pt=C&&A zgUmlw=3Pk-jR+K|N>&P)a*ZTDvvjbVp9#E>=W8|WhH7y5B%*y*U)(pdh2va!)kpLY z8p6}r5DbXno~)q?I-M?aj7;md99~49Bqi~L_dZuipjp^Jv&6*@l*ppUSmUKaR$HwC zc<|k~fgM@x3cRCZ1#r!K@PrUC^5p`NcO5R_&k^$>?<(V5UD)sdupKq8!Me*n?*i?A zHYZ~4K0|NBz*+2_%3#tn5Qz&^&GC%QraHcoNaxKkKY^DwcS4#EeC zdWR}WYRUxgkd^!W$c0mg$7w`X0q!dQa3Xjod=)jqRW7{35uZek<$UZObFaZ2D_{A9 zDJ?iv4qeN$No}^Yxs$e2m^0D3JStaEFhP&R8E`yIDr`4n0PSAx+e|xri7~hig`|Gu zcFT<+JXr%18fTPHgQ5z{73e(jdUv%3cVusYY_VNN8v1$r$&fhO+d+Lhdh>WbZXHpV zuS31pEx zxiFCIMUg@He_7z5@xK)D)c?ocn}hkbO5wk)%oX zbz1C`Eo3)EmaHK})`S!$+1JUIrR>XCvJ(pxP1bFL>dkMv(a&U z$S2AWS_Q7a3>kl9FR8BrS{68UlB##=SutX;5-EksgTWfi*1ik2fS^lQztPrsdh< z8UM(k>hU+qKG>9fx3k(He0~2L@XWWmDsSvH05Rm7@z#cNLU~r2HQa`P(g{|R_ZL!XRzr+e>&+uyV|XC z2%=?5S;wP2cABVl#;7w+dY-_A|3I&yKgXS?d{i)QH7^q3ZjMRiDMmID&j&vE8@}t_= zdd1ELeRx@PW6X1QDtvAT_PL&LxDG-u!-pkQ&sH2O-Ob^3vk(?*agqF);1)+Y1BW2U zy!aGcvvhvjY`-!#zuRfKGTGZ=bT2Xv%6 zj0j6lUPkvZTdT|sXTkSZ^y2|RucQPmzVJnSB(w^)8^5!P_C1p)CgVMmfXkVkuqVbZ 
z4-5lK_H@Sg_in1^MFLCqCIU#&;m=sagwgiM{Db~^C*!=M)+IQ~_j)niYUk~?q^gDI oZtXo$Xw9sQvt*gf{aVy+f6*@$xa_~{q57|QPiFVu>(kovUjW+IkN^Mx literal 0 HcmV?d00001 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst new file mode 100644 index 000000000..546ce168b --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/index.rst @@ -0,0 +1,9 @@ +LibriSpeech +=========== + +.. toctree:: + :maxdepth: 1 + + pruned_transducer_stateless + + lstm_pruned_stateless_transducer diff --git a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst similarity index 100% rename from docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst rename to docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst new file mode 100644 index 000000000..de7102ba8 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -0,0 +1,735 @@ +Pruned transducer statelessX +============================ + +This tutorial shows you how to run a **streaming** conformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + We will take pruned_transducer_stateless4 as an example in this tutorial. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Conformer model (the reworked version by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. 
HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +.. NOTE:: + + We put the streaming and non-streaming model in one recipe, to train a streaming model you only + need to add **4** extra options comparing with training a non-streaming model. These options are + ``--dynamic-chunk-training``, ``--num-left-chunks``, ``--causal-convolution``, ``--short-chunk-size``. + You can see the configurable options below for their meanings or read https://arxiv.org/pdf/2012.05481.pdf for more details. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless4/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless4/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless4/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. 
+ Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + - ``--dynamic-chunk-training`` + + The flag that indicates whether to train a streaming model or not, it + **MUST** be True if you want to train a streaming model. + + - ``--short-chunk-size`` + + When training a streaming attention model with chunk masking, the chunk size + would be either max sequence length of current batch or uniformly sampled from + (1, short_chunk_size). The default value is 25, you don't have to change it most of the time. + + - ``--num-left-chunks`` + + It indicates how many left context (in chunks) that can be seen when calculating attention. + The default value is 4, you don't have to change it most of the time. + + + - ``--causal-convolution`` + + Whether to use causal convolution in conformer encoder layer, this requires + to be True when training a streaming model. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless4/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless4/train.py`` directly. + + +.. NOTE:: + + The options for `pruned_transducer_stateless5 `_ are a little different from + other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless4/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. 
code-block:: bash + + $ cd pruned_transducer_stateless4/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/ + + [2022-11-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. figure:: images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --max-duration 300 + +.. NOTE:: + + Comparing with training a non-streaming model, you only need to add two extra options, + ``--dynamic-chunk-training 1`` and ``--causal-convolution 1`` . + + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. tip:: + + To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or + ``real streaming decoding`` in ``streaming_decode.py``, the difference between ``decode.py`` and + ``streaming_decode.py`` is that, ``decode.py`` processes the whole acoustic frames at one time with masking (i.e. same as training), + but ``streaming_decode.py`` processes the acoustic frames chunk by chunk (so it can only see limited context). + +.. 
NOTE:: + + ``simulate streaming decoding`` in ``decode.py`` and ``real streaming decoding`` in ``streaming_decode.py`` should + produce almost the same results given the same ``--decode-chunk-size`` and ``--left-context``. + + +Simulate streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--simulate-streaming`` + + If you want to decode a streaming model with ``decode.py``, you **MUST** set + ``--simulate-streaming`` to ``True``. ``simulate`` here means the acoustic frames + are not processed frame by frame (or chunk by chunk), instead, the whole sequence + is processed at one time with masking (the same as training). + + ``--causal-convolution`` + + If True, the convolution module in encoder layers will be causal convolution. + This is **MUST** be True when decoding with a streaming model. + + ``--decode-chunk-size`` + + For streaming models, we will calculate the chunk-wise attention, ``--decode-chunk-size`` + indicates the chunk length (in frames after subsampling) for chunk-wise attention. + For ``simulate streaming decoding`` the ``decode-chunk-size`` is used to generate + the attention mask. + + ``--left-context`` + + ``--left-context`` indicates how many left context frames (after subsampling) can be seen + for current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal + to ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used + to train this model. For ``simulate streaming decoding`` the ``left-context`` is used to generate + the attention mask. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless4/decode.py \ + --iter $iter \ + --avg $avg \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +Real streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/streaming_decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-size`` + + For streaming models, we will calculate the chunk-wise attention, ``--decode-chunk-size`` + indicates the chunk length (in frames after subsampling) for chunk-wise attention. + For ``real streaming decoding``, we will process ``decode-chunk-size`` acoustic frames at each time. + + ``--left-context`` + + ``--left-context`` indicates how many left context frames (after subsampling) can be seen + for current chunk when calculating chunk-wise attention. 
+  Normally, ``left-context`` should be equal to ``decode-chunk-size * num-left-chunks``,
+  where ``num-left-chunks`` is the option used to train this model.
+
+  ``--num-decode-streams``
+
+  The number of decoding streams that can be run in parallel (very similar to the ``batch size``).
+  For ``real streaming decoding``, the batches will be packed dynamically. For example, if
+  ``num-decode-streams`` equals 10, then sequences 1 to 10 will be decoded first; after a while,
+  suppose sequences 1 and 2 are done, then sequences 3 to 12 will be processed in parallel in a batch.
+
+
+.. NOTE::
+
+  We also tried adding ``--right-context`` in the real streaming decoding, but it does not seem to
+  benefit the performance of all the models; the reason might be the mismatch between training and
+  decoding. You can try decoding with ``--right-context`` to see if it helps. The default value is 0.
+
+
+The following shows two examples (for the two types of checkpoints):
+
+.. code-block:: bash
+
+    for m in greedy_search fast_beam_search modified_beam_search; do
+      for epoch in 25 20; do
+        for avg in 7 5 3 1; do
+          ./pruned_transducer_stateless4/streaming_decode.py \
+            --epoch $epoch \
+            --avg $avg \
+            --decode-chunk-size 16 \
+            --left-context 64 \
+            --num-decode-streams 100 \
+            --exp-dir pruned_transducer_stateless4/exp \
+            --max-duration 600 \
+            --decoding-method $m
+        done
+      done
+    done
+
+
+.. code-block:: bash
+
+    for m in greedy_search fast_beam_search modified_beam_search; do
+      for iter in 474000; do
+        for avg in 8 10 12 14 16 18; do
+          ./pruned_transducer_stateless4/streaming_decode.py \
+            --iter $iter \
+            --avg $avg \
+            --decode-chunk-size 16 \
+            --left-context 64 \
+            --num-decode-streams 100 \
+            --exp-dir pruned_transducer_stateless4/exp \
+            --max-duration 600 \
+            --decoding-method $m
+        done
+      done
+    done
+
+
+.. tip::
+
+  The supported decoding methods are as follows:
+
+  - ``greedy_search`` : It takes the symbol with the largest posterior probability
+    of each frame as the decoding result.
+
+  - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
+    `espnet/nets/beam_search_transducer.py `_
+    is used as a reference. Basically, it keeps the top-k states for each frame and expands the
+    kept states with their own contexts to the next frame.
+
+  - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+    runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+  - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+    given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+    our paper at https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+  - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+    a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+    (with an N-gram LM).
+
+  - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+    - (1) Use ``fast_beam_search`` to get a lattice
+    - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+    - (3) Unique the selected paths
+    - (4) Intersect the selected paths with the lattice and compute the
+      shortest path from the intersection result
+    - (5) The path with the largest score is used as the decoding output.
+
+  - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``, the
+    only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
+.. NOTE::
+
+  The supported decoding methods in ``streaming_decode.py`` might be fewer than those in ``decode.py``; if needed,
+  you can implement them by yourself or file an issue in `icefall `_ .
+
+
+Export Model
+------------
+
+`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+    # Assume that --epoch 25 --avg 3 produces the smallest WER
+    # (You can get such information after running ./pruned_transducer_stateless4/decode.py)
+
+    epoch=25
+    avg=3
+
+    ./pruned_transducer_stateless4/export.py \
+      --exp-dir ./pruned_transducer_stateless4/exp \
+      --streaming-model 1 \
+      --causal-convolution 1 \
+      --bpe-model data/lang_bpe_500/bpe.model \
+      --epoch $epoch \
+      --avg $avg
+
+.. caution::
+
+  ``--streaming-model`` and ``--causal-convolution`` are required to be True to export
+  a streaming model.
+
+It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
+
+.. hint::
+
+  To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``,
+  you can run:
+
+  .. code-block:: bash
+
+      cd pruned_transducer_stateless4/exp
+      ln -s pretrained.pt epoch-999.pt
+
+  And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to
+  ``./pruned_transducer_stateless4/decode.py``.
+
+To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+    ./pruned_transducer_stateless4/pretrained.py \
+      --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \
+      --simulate-streaming 1 \
+      --causal-convolution 1 \
+      --bpe-model ./data/lang_bpe_500/bpe.model \
+      --method greedy_search \
+      /path/to/foo.wav \
+      /path/to/bar.wav
+
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+    ./pruned_transducer_stateless4/export.py \
+      --exp-dir ./pruned_transducer_stateless4/exp \
+      --streaming-model 1 \
+      --causal-convolution 1 \
+      --bpe-model data/lang_bpe_500/bpe.model \
+      --epoch 25 \
+      --avg 3 \
+      --jit 1
+
+.. caution::
+
+  ``--streaming-model`` and ``--causal-convolution`` are required to be True to export
+  a streaming model.
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters, when loaded into Python,
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
+.. NOTE::
+
+  You will need this ``cpu_jit.pt`` when deploying with the Sherpa framework.
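+.. hint::
+
+  The following is a minimal sketch of how the exported ``cpu_jit.pt`` could be loaded
+  and moved to a CUDA device in Python. It only illustrates the ``torch.jit.load`` and
+  ``to("cuda")`` steps mentioned above; the submodule names ``encoder``, ``decoder`` and
+  ``joiner`` printed at the end are assumptions for illustration, so please refer to
+  ``decode.py`` or the Sherpa documentation for the exact calling convention.
+
+  .. code-block:: python
+
+      import torch
+
+      # Load the torchscript model exported with --jit 1.
+      # Its parameters are on CPU at this point, hence the name cpu_jit.pt.
+      model = torch.jit.load("cpu_jit.pt")
+      model.eval()
+
+      # Optionally move the parameters to a CUDA device.
+      if torch.cuda.is_available():
+          model = model.to("cuda")
+
+      # The scripted model is assumed to expose its submodules by name
+      # (this is how the decoding scripts use it).
+      print(model.encoder)
+      print(model.decoder)
+      print(model.joiner)
+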
+ + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless `_ + + - `pruned_transducer_stateless2 `_ + + - `pruned_transducer_stateless4 `_ + + - `pruned_transducer_stateless5 `_ + + See ``_ + for the details of the above pretrained models + + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 9d1d83d29..63793275c 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -13,7 +13,5 @@ We may add recipes for other tasks as well in the future. :maxdepth: 2 :caption: Table of Contents - aishell/index - librispeech/index - timit/index - yesno/index + Non-streaming-ASR/index + Streaming-ASR/index From 6d659f423dbb67b20309e00d5885b76c5dfd15e8 Mon Sep 17 00:00:00 2001 From: kobenaxie <572745565@qq.com> Date: Thu, 15 Dec 2022 20:42:07 +0800 Subject: [PATCH 078/120] delete duplicate line for encoder initial state (#765) --- .../ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 716de5734..64c16141c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -152,7 +152,6 @@ def export_encoder_model_jit_trace( x = torch.zeros(1, T, 80, dtype=torch.float32) states = encoder_model.init_states() - states = encoder_model.init_states() traced_model = torch.jit.trace(encoder_model, (x, states)) traced_model.save(encoder_filename) From fbc1d3b194cfb2be4d01de85d9c3a3ea13c961fb Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 17 Dec 2022 22:03:13 +0800 Subject: [PATCH 079/120] fix src_key_padding_mask in DownsampledZipformerEncoder (#768) --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index ed1e2efa2..71f12e44a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -741,7 +741,7 @@ class DownsampledZipformerEncoder(nn.Module): src, feature_mask=feature_mask, mask=mask, - src_key_padding_mask=mask, + src_key_padding_mask=src_key_padding_mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor From 65d7192dca03ba21bff4270add3891c9730491a7 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 19 Dec 2022 20:10:39 +0800 Subject: [PATCH 080/120] Fix zipformer attn_output_weights (#774) * fix attn_output_weights * remove in-place op --- .../pruned_transducer_stateless7/zipformer.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 71f12e44a..ad3b88df0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1291,9 +1291,11 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask 
is not None: if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) else: - attn_output_weights += attn_mask + attn_output_weights = attn_output_weights + attn_mask if key_padding_mask is not None: attn_output_weights = attn_output_weights.view( @@ -1313,6 +1315,34 @@ class RelPositionMultiheadAttention(nn.Module): # only storing the half-precision output for backprop purposes. attn_output_weights = softmax(attn_output_weights, dim=-1) + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + attn_output_weights = nn.functional.dropout( attn_output_weights, p=dropout_p, training=training ) From 070c77e724d4da91900925c44b237523d97f9f08 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 21 Dec 2022 17:41:31 +0800 Subject: [PATCH 081/120] Add Blankskip to Zipformer+CTC (#730) * init files * add ctc as auxiliary loss and ctc_decode.py * tuning the scalar of HLG score for 1best, nbest and nbest-oracle * rename to pruned_transducer_stateless7_ctc * fix doc * fix bug, recover the hlg scores * modify ctc_decode.py, move out the hlg scale * fix hlg_scale * add export.py and pretrained.py, and so on * upload files, update README.md and RESULTS.md * add CI test * update .gitignore * create symlinks * Add Blank Skip to Zipformer+CTC * Add warmup to blank skip * Add warmup to blank skip * Add __init__.py * Add parameters_names to Adam * Add warmup to blank skip * Modify frame_reducer * Modify frame_reducer * Add Blank Skip to decode. 
* Add ctc_decode.py * Add blank skip to Zipformer+CTC * process conflict * process conflict * modify ctc_guild_decode_bk.py * modify Lconv * produce the conflict * Add export.py * finish export * fix for running black * Add ci test * Add ci-test * chmod * chmod * fix bug for ci-test * fix bug for ci-test * fix bug for ci-test * rename the dirname * rename the dirname * change dirname * change dirname * fix notes * add pretrained.py * add pretrained.py * add pretrained.py * add pretrained.py * add pretrained.py * add pretrained.py * fix * fix * fix * finished * add the Copyright info and notes Co-authored-by: Zengwei Yao Co-authored-by: yifanyang --- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 2 +- ...transducer-stateless7-ctc-bs-2022-12-15.sh | 148 ++ ...brispeech-2022-12-15-stateless7-ctc-bs.yml | 163 +++ .gitignore | 1 + egs/gigaspeech/ASR/.gitignore | 1 + egs/librispeech/ASR/.gitignore | 1 + .../ASR/pruned_transducer_stateless7/optim.py | 2 +- .../jit_pretrained_ctc.py | 8 +- .../__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../ctc_decode.py | 809 +++++++++++ .../ctc_guild_decode_bs.py | 857 +++++++++++ .../decode.py | 841 +++++++++++ .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 319 +++++ .../frame_reducer.py | 84 ++ .../jit_pretrained.py | 271 ++++ .../jit_pretrained_ctc.py | 426 ++++++ .../joiner.py | 1 + .../lconv.py | 114 ++ .../model.py | 224 +++ .../optim.py | 1 + .../pretrained.py | 352 +++++ .../pretrained_ctc.py | 440 ++++++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../test_model.py | 55 + .../train.py | 1251 +++++++++++++++++ .../zipformer.py | 1 + 31 files changed, 6372 insertions(+), 6 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh create mode 100644 .github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py create mode 120000 
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index e081c9374..3cbb480f6 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -148,4 +148,4 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless7_ctc/exp/*.pt -fi +fi \ No newline at end of file diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh new file mode 100755 index 000000000..ed66a728e --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + 
--checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_ctc_bs/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_ctc_bs/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_ctc_bs/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 999 \ + --avg 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration $max_duration \ + --use-averaged-model 0 \ + --decoding-method $m \ + --hlg-scale 0.6 + done + + rm pruned_transducer_stateless7_ctc_bs/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml new file mode 100644 index 000000000..6e2b40cf3 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml @@ -0,0 +1,163 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-librispeech-2022-12-15-stateless7-ctc-bs +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_12_15_zipformer_ctc_bs: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7_ctc_bs/exp + + cd pruned_transducer_stateless7_ctc_bs + echo "results for pruned_transducer_stateless7_ctc_bs" + echo "===greedy search===" + 
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===ctc decoding===" + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15 + path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/ diff --git a/.gitignore b/.gitignore index 583410f45..8af05d884 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ node_modules *.param *.bin +.DS_Store diff --git a/egs/gigaspeech/ASR/.gitignore b/egs/gigaspeech/ASR/.gitignore index 5592679cc..8dec2d86d 100644 --- a/egs/gigaspeech/ASR/.gitignore +++ b/egs/gigaspeech/ASR/.gitignore @@ -1 +1,2 @@ log-* +.DS_Store \ No newline at end of file diff --git a/egs/librispeech/ASR/.gitignore b/egs/librispeech/ASR/.gitignore index 5592679cc..8dec2d86d 100644 --- a/egs/librispeech/ASR/.gitignore +++ b/egs/librispeech/ASR/.gitignore @@ -1 +1,2 @@ log-* +.DS_Store \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index ff8fbb32c..374b78cb3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# Copyright 2022 Xiaomi Corp. 
(authors: Daniel Povey) # # See ../LICENSE for clarification regarding multiple authors # diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index ad9cf08dc..d50d231d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -31,7 +31,7 @@ Usage of this script: (1) ctc-decoding ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ --sample-rate 16000 \ @@ -40,7 +40,7 @@ Usage of this script: (2) 1best ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --method 1best \ @@ -51,7 +51,7 @@ Usage of this script: (3) nbest-rescoring ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --G data/lm/G_4_gram.pt \ @@ -63,7 +63,7 @@ Usage of this script: (4) whole-lattice-rescoring ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --G data/lm/G_4_gram.pt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py new file mode 100755 index 000000000..0ef733226 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -0,0 +1,809 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding +(2) 1best +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best +(3) nbest +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best +(4) nbest-rescoring +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring +(5) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. 
For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) + nnet_output = model.ctc_output(encoder_out) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. 
+ # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + 
max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py new file mode 100755 index 000000000..9c2166aaf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py @@ -0,0 +1,857 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang,) +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model +from torch.nn.utils.rnn import pad_sequence + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
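+
+ Note:
+ Before any of the decoding methods is applied, this function filters the
+ encoder output using the CTC posteriors (`model.lconv` followed by
+ `model.frame_reducer`), and the searches then run on the reduced frames.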
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + # filter out blank frames using ctc outputs + ctc_output = model.ctc_output(encoder_out) + encoder_out = model.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(encoder_out_lens), + ) + encoder_out, encoder_out_lens = model.frame_reducer( + x=encoder_out, + x_lens=encoder_out_lens, + ctc_output=ctc_output, + blank_id=0, + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == 
"greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py new file mode 100755 index 000000000..ce45a4beb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
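+#
+# This is the standard transducer decoding script of this recipe: unlike
+# ctc_guild_decode_bs.py, decode_one_batch() here passes the full encoder
+# output to the search, without the CTC-based frame reduction step.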
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py new file mode 100755 index 000000000..96d316604 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. 
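+
+A minimal sketch of loading the exported `cpu_jit.pt` in Python (the path is
+only an example; adjust it to your exp dir):
+
+ import torch
+ model = torch.jit.load("pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt")
+ model.eval()
+ # the scripted module is then used through model.encoder, model.decoder and
+ # model.joiner, as in ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py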
+ +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_ctc_bs/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_ctc_bs/decode.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. 
+ """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py new file mode 100755 index 000000000..3de21a293 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from icefall.utils import make_pad_mask + + +class FrameReducer(nn.Module): + """The encoder output is first used to calculate + the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some thresholds, + it will be simply discarded from the encoder output. + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ctc_output: torch.Tensor, + blank_id: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The shared encoder output with shape [N, T, C]. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + ctc_output: + The CTC output with shape [N, T, vocab_size]. + blank_id: + The ID of the blank symbol. + Returns: + x_fr: + The frame reduced encoder output with shape [N, T', C]. + x_lens_fr: + A tensor of shape (batch_size,) containing the number of frames in + `x_fr` before padding. 
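+
+ Example (a rough usage sketch; `x`, `x_lens` and `ctc_output` stand for
+ the encoder output, its valid lengths and the CTC log-probabilities):
+
+ frame_reducer = FrameReducer()
+ x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, blank_id=0)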
+ """ + + padding_mask = make_pad_mask(x_lens) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + T_range = torch.arange(x.shape[1], device=x.device) + + frames_list: List[torch.Tensor] = [] + lens_list: List[int] = [] + for i in range(x.shape[0]): + indexes = torch.masked_select( + T_range, + non_blank_mask[i], + ) + frames = x[i][indexes] + frames_list.append(frames) + lens_list.append(frames.shape[0]) + x_fr = pad_sequence(frames_list).transpose(0, 1) + x_lens_fr = torch.tensor(lens_list).to(device=x.device) + + return x_fr, x_lens_fr diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py new file mode 100755 index 000000000..da2c6a39a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = 
fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py new file mode 100755 index 000000000..653c25e06 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
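The batched greedy_search above leans on torch.nn.utils.rnn.pack_padded_sequence: frames are processed in time slices whose width shrinks as shorter utterances finish, and the final hypotheses are put back into input order via unsorted_indices. A small stand-alone demonstration of that bookkeeping (toy tensors only, no transducer involved):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

encoder_out = torch.arange(12.0).reshape(3, 4, 1)   # (N=3, T=4, C=1)
encoder_out_lens = torch.tensor([2, 4, 3])

packed = pack_padded_sequence(
    input=encoder_out,
    lengths=encoder_out_lens.cpu(),
    batch_first=True,
    enforce_sorted=False,
)

# batch_sizes[t] is how many utterances are still active at frame t:
# here [3, 3, 2, 1] because one utterance ends after 2 frames, another after 3.
print(packed.batch_sizes.tolist())

# packed.data stacks the active frames slice by slice, with utterances sorted by
# decreasing length; unsorted_indices restores the original batch order at the end.
print(packed.sorted_indices.tolist(), packed.unsorted_indices.tolist())

offset = 0
for batch_size in packed.batch_sizes.tolist():
    frames = packed.data[offset : offset + batch_size]  # (batch_size, C)
    offset += batch_size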
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. 
+ """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py new file mode 100755 index 000000000..bfd49d533 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -0,0 +1,114 @@ +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from scaling import ( + ActivationBalancer, + ScaledConv1d, +) + + +class LConv(nn.Module): + """A convolution module to prevent information loss.""" + + def __init__( + self, + channels: int, + kernel_size: int = 7, + bias: bool = True, + ): + """ + Args: + channels: + Dimension of the input embedding, and of the lconv output. 
+ """ + super().__init__() + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + self.depthwise_conv = nn.Conv1d( + 2 * channels, + 2 * channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + 2 * channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.pointwise_conv2 = ScaledConv1d( + 2 * channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: A 3-D tensor of shape (N, T, C). + Returns: + Return a tensor of shape (N, T, C). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(0, 2, 1) # (#batch, channels, time). + + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + + x = self.pointwise_conv2(x) # (batch, channels, time) + + return x.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py new file mode 100755 index 000000000..86acc5a10 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -0,0 +1,224 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos, make_pad_mask + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + lconv: nn.Module, + frame_reducer: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). 
+ Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + self.lconv = lconv + self.frame_reducer = frame_reducer + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A floating point value which decides whether to do blank skip. + Returns: + Return a tuple containing simple loss, pruned loss, and ctc-output. + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # compute ctc log-probs + ctc_output = self.ctc_output(encoder_out) + + # blank skip + blank_id = self.decoder.blank_id + + if warmup >= 2.0: + # lconv + encoder_out = self.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(x_lens), + ) + + # frame reduce + encoder_out_fr, x_lens_fr = self.frame_reducer( + encoder_out, + x_lens, + ctc_output, + blank_id, + ) + else: + encoder_out_fr = encoder_out + x_lens_fr = x_lens + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens_fr + + am = self.simple_am_proj(encoder_out_fr) + lm = self.simple_lm_proj(decoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out_fr), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss, ctc_output) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py new file mode 100755 index 000000000..ea0fe9164 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
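In Transducer.forward above, the boundary tensor tells the k2 RNN-T losses how long each utterance really is after frame reduction: columns 0 and 1 stay at zero, column 2 carries the label length and column 3 the (possibly reduced) frame count. A tiny stand-alone sketch of that layout with hypothetical lengths (no k2 call, just the tensor construction):

import torch

# Hypothetical per-utterance label lengths and frame counts after blank skip.
y_lens = torch.tensor([7, 3])          # number of BPE tokens per utterance
x_lens_fr = torch.tensor([23, 11])     # frames surviving the frame reducer

boundary = torch.zeros((y_lens.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens_fr
print(boundary)
# tensor([[ 0,  0,  7, 23],
#         [ 0,  0,  3, 11]])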
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_ctc_bs/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_ctc_bs/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + 
encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py new file mode 100755 index 000000000..412631ba1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
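The pretrained and jit_pretrained scripts above all share the same front end: 80-dim fbank features computed by kaldifeat with dithering disabled, then padded with pad_sequence using log(1e-10) (a very small log-energy) as the padding value. A sketch of that front end on synthetic 16 kHz audio, assuming kaldifeat is installed:

import math
import torch
import kaldifeat
from torch.nn.utils.rnn import pad_sequence

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

# Two fake waveforms (1.0 s and 0.6 s of noise) standing in for loaded wav files.
waves = [torch.randn(16000), torch.randn(9600)]

features = fbank(waves)                            # list of (num_frames, 80) tensors
feature_lengths = torch.tensor([f.size(0) for f in features])

features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(features.shape, feature_lengths)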
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. 
Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # 
G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py new file mode 100755 index 000000000..7f0893985 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_ctc_bs/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model_1() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py new file mode 100755 index 000000000..63e9d6e90 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --full-libri 1 \ + --max-duration 300 +# For mix precision training: +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lconv import LConv +from frame_reducer import FrameReducer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.5, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + Explanation of options saved in `params`: + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + - best_train_epoch: It is the epoch that has the best training loss. + - best_valid_epoch: It is the epoch that has the best validation loss. + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - valid_interval: Run validation if batch_idx % valid_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + - subsampling_factor: The subsampling factor for the model. + - encoder_dim: Hidden dim for multi-head attention model. 
+ - num_decoder_layers: Number of decoder layer of transformer decoder. + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_lconv(params: AttributeDict) -> nn.Module: + lconv = LConv( + channels=int(params.encoder_dims.split(",")[-1]), + ) + return lconv + + +def get_frame_reducer(params: AttributeDict) -> nn.Module: + frame_reducer = FrameReducer() + return frame_reducer + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + lconv = get_lconv(params) + frame_reducer = get_frame_reducer(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + lconv=lconv, + frame_reducer=frame_reducer, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. 
+ Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
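+    Returns:
+      Return a tuple of two elements: the loss and a MetricsTracker object
+      containing the statistics of this batch.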
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + warmup = batch_idx_train / warm_step + + texts = batch["supervisions"]["text"] + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
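+    # The entries below are sums over the utterances in the batch; they are
+    # normalized by info["frames"] when reported, e.g. via
+    # tot_loss["loss"] / tot_loss["frames"] in the training loop.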
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
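+            # GradScaler flow for mixed precision: scale the loss before
+            # backward so fp16 gradients do not underflow; scaler.step()
+            # unscales the gradients and skips the optimizer update if
+            # inf/nan is detected; scaler.update() then adjusts the scale
+            # for the next iteration.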
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file From 7eb2d0edb637e8000e9c1a5eafe157d967d932ec Mon Sep 17 00:00:00 2001 From: BuaaAlban Date: Fri, 23 Dec 2022 11:38:22 +0800 Subject: [PATCH 082/120] Update train.py (#773) Fix transducer lstm egs bug as mentioned in issue 579 --- egs/librispeech/ASR/transducer_lstm/train.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 792708bc0..a6f2bd08c 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -629,18 +629,8 @@ def run(rank, world_size, args): # Keep only utterances with duration between 1 second and 20 seconds return 1.0 <= c.duration <= 20.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = librispeech.train_dataloaders(train_cuts) valid_cuts = librispeech.dev_clean_cuts() From 59eb465b3cd47a212117b535644f24ed190093e1 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 23 Dec 2022 17:55:36 +0800 Subject: [PATCH 083/120] optimize frame_reducer.py (#783) Co-authored-by: yifanyang --- .../pruned_transducer_stateless7_ctc_bs/frame_reducer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 3de21a293..9fe88929d 100755 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -66,19 +66,14 @@ class FrameReducer(nn.Module): padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - T_range = torch.arange(x.shape[1], device=x.device) frames_list: List[torch.Tensor] = [] lens_list: List[int] = [] for i in range(x.shape[0]): - indexes = torch.masked_select( - T_range, - non_blank_mask[i], - ) - frames = x[i][indexes] + frames = x[i][non_blank_mask[i]] frames_list.append(frames) lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list).transpose(0, 1) + x_fr = pad_sequence(frames_list, batch_first=True) x_lens_fr = torch.tensor(lens_list).to(device=x.device) return x_fr, x_lens_fr From 4e249da2c402eb83e6206365c161693d2f5db070 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 26 Dec 2022 14:30:20 +0800 Subject: [PATCH 084/120] Add zipformer_ctc_blankskip.rst (#784) * Add zipformer_ctc_blankskip.rst * typo fix for zipformer_mmi.rst * fix warning * Update docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst Co-authored-by: yifanyang Co-authored-by: Fangjun Kuang --- .../export-with-torch-jit-script.rst | 2 +- .../aishell/conformer_ctc.rst | 2 +- .../librispeech/conformer_ctc.rst | 8 +- .../Non-streaming-ASR/librispeech/index.rst | 2 +- .../pruned_transducer_stateless.rst | 3 +- .../librispeech/zipformer_ctc_blankskip.rst | 453 ++++++++++++++++++ .../librispeech/zipformer_mmi.rst | 4 +- 7 files changed, 464 insertions(+), 10 deletions(-) create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst index a041dc1d5..efd7dc2e1 100644 --- a/docs/source/model-export/export-with-torch-jit-script.rst +++ b/docs/source/model-export/export-with-torch-jit-script.rst @@ -1,7 +1,7 @@ .. _export-model-with-torch-jit-script: Export model with torch.jit.script() -=================================== +==================================== In this section, we describe how to export a model via ``torch.jit.script()``. diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst index 72690e102..6e30ce397 100644 --- a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst @@ -703,7 +703,7 @@ It will show you the following message: HLG decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst index 4656acfd6..b7f89c89f 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst @@ -888,7 +888,7 @@ It will show you the following message: CTC decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash @@ -926,7 +926,7 @@ Its output is: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION HLG decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash @@ -966,7 +966,7 @@ The output is: HLG decoding + n-gram LM rescoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
code-block:: bash @@ -1012,7 +1012,7 @@ The output is: HLG decoding + n-gram LM rescoring + attention decoder rescoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst index aa97f325d..3ebb36b25 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst @@ -7,5 +7,5 @@ LibriSpeech tdnn_lstm_ctc conformer_ctc pruned_transducer_stateless - lstm_pruned_stateless_transducer zipformer_mmi + zipformer_ctc_blankskip diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index d8569bc5c..86d43c8fe 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -499,9 +499,10 @@ can run: Export model using ``torch.jit.script()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash + ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ --bpe-model data/lang_bpe_500/bpe.model \ diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst new file mode 100644 index 000000000..d85a3c67f --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -0,0 +1,453 @@ +Zipformer CTC Blank Skip +======================== + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +This tutorial shows you how to train a Zipformer model based on the guidance from +a co-trained CTC model using `blank skip method `_ +with the `LibriSpeech `_ dataset. + +.. note:: + + We use both CTC and RNN-T loss to train. During the forward pass, the encoder output + is first used to calculate the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some threshold, it will be simply discarded + from the encoder output. To prevent information loss, we also put a convolution module + similar to the one used in conformer (referred to as “LConv”) before the frame reduction. + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +.. note:: + + We encourage you to read ``./prepare.sh``. + +The data preparation contains several stages. You can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. 
hint:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. note:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + +Training +-------- + +For stability, it doesn`t use blank skip method until model warm-up. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/train.py --help + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless7_ctc_bs/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless7_ctc_bs/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless7_ctc_bs/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. 
+ + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., weight decay, +number of warmup steps, results dir, etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless7_ctc_bs/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless7_ctc_bs/train.py`` directly. + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``pruned_transducer_stateless7_ctc_bs/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard + $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/ + + Note there is a URL in the above output. Click it and you will see + tensorboard. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. 
code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --full-libri 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --use-fp16 1 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help + +shows the options for decoding. + +The following shows the example using ``epoch-*.pt``: + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 30 \ + --avg 13 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method $m + done + +To test CTC branch, you can use the following command: + +.. code-block:: bash + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 30 \ + --avg 13 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method $m + done + +Export models +------------- + +`pruned_transducer_stateless7_ctc_bs/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless7_ctc_bs/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless7_ctc_bs/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 0 + +It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless7_ctc_bs/exp + ln -s pretrained epoch-9999.pt + + And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``. + +To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +To test CTC branch using the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py``: + +.. 
code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +To use the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav + +To test CTC branch using the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - ``_ + + See ``_ + for the details of the above pretrained models diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst index db268dd02..a7b59a992 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst @@ -272,7 +272,7 @@ You will find the following files in that directory: Usage example ~~~~~~~~~~~~~ -You can use the following command to start the training using 8 GPUs: +You can use the following command to start the training using 4 GPUs: .. code-block:: bash @@ -382,7 +382,7 @@ can run: /path/to/bar.wav Export model using ``torch.jit.script()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash From dfbcf606e7a7798bc5d9f73da82126914800be0e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 27 Dec 2022 09:25:42 +0800 Subject: [PATCH 085/120] small fixes to prepare.sh (#789) --- egs/librispeech/ASR/prepare.sh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 59bed8389..b1d207049 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -123,10 +123,12 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then touch data/fbank/.librispeech.done fi - cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + if [ ! 
-f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + fi if [ ! -e data/fbank/.librispeech-validated.done ]; then log "Validating data/fbank for LibriSpeech" @@ -244,7 +246,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram P" + log "Stage 7: Prepare bigram token-level P for MMI training" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} From 88b7895adf03424497619b54cdd9a230e9216b5c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 27 Dec 2022 13:59:55 +0800 Subject: [PATCH 086/120] fix librispeech.py in multi-dataset setup (#791) --- .../ASR/pruned_transducer_stateless3/librispeech.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py index 6dba8e9fe..9f2cb6225 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py @@ -72,3 +72,12 @@ class LibriSpeech: f = self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" logging.info(f"About to get dev-other cuts from {f}") return load_manifest_lazy(f) + + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) From a24a1cbfa9ffd03a629b988ae19e5d35248e72ec Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 27 Dec 2022 15:06:53 +0800 Subject: [PATCH 087/120] small fix for zipformer_ctc_blankskip.rst (#792) Co-authored-by: yifanyang --- .../Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index d85a3c67f..56a420605 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -228,7 +228,7 @@ You will find the following files in that directory: .. code-block:: bash $ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard - $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall" + $ tensorboard dev upload --logdir . 
--description "Zipformer-CTC co-training using blank skip for LibriSpeech with icefall" It will print something like below: From 05dfd5e630d525dcc8828feba4d9daf6624af319 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 27 Dec 2022 15:26:11 +0800 Subject: [PATCH 088/120] Fix distillation with HuBERT (#790) * update vq huggingface url * remove hard lhotse version requirement * resolve ID mismatch * small fixes * Update egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py Co-authored-by: Fangjun Kuang * update version check Co-authored-by: Fangjun Kuang --- .../ASR/distillation_with_hubert.sh | 12 +++++-- .../pruned_transducer_stateless6/vq_utils.py | 34 ++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index 2a69d3921..d5d3008aa 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -35,7 +35,7 @@ stop_stage=4 # export CUDA_VISIBLE_DEVICES="0" # # Suppose GPU 2,3,4,5 are available. -export CUDA_VISIBLE_DEVICES="0,1,2,3" +# export CUDA_VISIBLE_DEVICES="0,1,2,3" exp_dir=./pruned_transducer_stateless6/exp mkdir -p $exp_dir @@ -49,7 +49,7 @@ full_libri=False # "True" -> stage 0 and stage 1 would be skipped, # and directly download the extracted codebook indexes for distillation # "False" -> start from scratch -use_extracted_codebook=False +use_extracted_codebook=True # teacher_model_id can be one of # "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use. @@ -155,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi log "Downloading extracted codebook indexes to $codebook_download_dir" # Make sure you have git-lfs installed (https://git-lfs.github.com) + # The codebook indexes are generated using lhotse 1.11.0, to avoid + # potential issues, we recommend you to use lhotse version >= 1.11.0 + lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))") + if [ "$lhotse_version" == "False" ]; then + log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch." + fi git lfs install - git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir + git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir mkdir -p data/vq_fbank mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 97a83b974..bf072d865 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -244,10 +244,36 @@ class CodebookIndexExtractor: ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) - cuts_vq = cuts_vq.sort_like(cuts_ori) - for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): - assert cut_vq.id == cut_ori.id - cut_ori.codebook_indexes = cut_vq.codebook_indexes + assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!" 
+ + if set(cuts_vq.ids) == set(cuts_ori.ids): + # IDs match exactly + cuts_vq = cuts_vq.sort_like(cuts_ori) + for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): + assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id) + cut_ori.codebook_indexes = cut_vq.codebook_indexes + else: + # in case of ID mismatch, remap them + # get the mapping between audio and cut ID + logging + ori_id_map = {} + for id in cuts_ori.ids: + # some text normalization + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + ori_id_map[clean_id] = id + + for id in cuts_vq.ids: + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + assert clean_id in ori_id_map, clean_id + cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[ + id + ].codebook_indexes CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path) logging.info(f"Processed {subset}.") From 3c54333b06a87bb2efc665c66d3c25370033d182 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Wed, 28 Dec 2022 11:20:38 +0800 Subject: [PATCH 089/120] fix bug (#796) --- .../pruned_transducer_stateless5/conformer.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py index 9bb55d07a..23a877b2f 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py @@ -966,20 +966,32 @@ class RelPositionMultiheadAttention(nn.Module): (batch_size, num_heads, time1, n) = x.shape time2 = time1 + left_context - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" + if not torch.jit.is_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + if torch.jit.is_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, From 1f0408b1031dccfcb13ae3641b576434aec4f983 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 29 Dec 2022 10:53:36 +0800 Subject: [PATCH 090/120] Support Transformer LM (#750) * support transformer LM * show number of parameters during training * update docstring * testing files for ppl calculation * add lm wrampper for rnn and transformer LM * apply lm wrapper in lm shallow fusion * small updates * update decode.py to support LM fusion and LODR * add 
export.py * update CI and workflow * update decoding results * fix CI * remove transformer LM from CI test --- ...h-lstm-transducer-stateless2-2022-09-03.sh | 24 +- ...-lstm-transducer-stateless2-2022-09-03.yml | 11 +- egs/librispeech/ASR/RESULTS.md | 66 +- .../ASR/lstm_transducer_stateless2/decode.py | 195 +++--- .../beam_search.py | 584 +++++++++-------- .../pruned_transducer_stateless3/decode.py | 186 +++--- .../pruned_transducer_stateless5/decode.py | 254 +++++--- .../pruned_transducer_stateless7/decode.py | 178 ++++- icefall/__init__.py | 2 + icefall/lm_wrapper.py | 254 ++++++++ icefall/rnn_lm/model.py | 21 +- icefall/rnn_lm/train.py | 3 + icefall/transformer_lm/attention.py | 510 +++++++++++++++ icefall/transformer_lm/compute_perplexity.py | 195 ++++++ icefall/transformer_lm/dataset.py | 1 + icefall/transformer_lm/encoder.py | 329 ++++++++++ icefall/transformer_lm/export.py | 186 ++++++ icefall/transformer_lm/model.py | 115 ++++ icefall/transformer_lm/scaling.py | 1 + icefall/transformer_lm/train.py | 609 ++++++++++++++++++ 20 files changed, 3086 insertions(+), 638 deletions(-) create mode 100644 icefall/lm_wrapper.py create mode 100644 icefall/transformer_lm/attention.py create mode 100644 icefall/transformer_lm/compute_perplexity.py create mode 120000 icefall/transformer_lm/dataset.py create mode 100644 icefall/transformer_lm/encoder.py create mode 100644 icefall/transformer_lm/export.py create mode 100644 icefall/transformer_lm/model.py create mode 120000 icefall/transformer_lm/scaling.py create mode 100644 icefall/transformer_lm/train.py diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index ac5b15979..9b883f889 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -193,7 +193,7 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then ls -lh data ls -lh lstm_transducer_stateless2/exp - log "Decoding test-clean and test-other" + log "Decoding test-clean and test-other with RNN LM" ./lstm_transducer_stateless2/decode.py \ --use-averaged-model 0 \ @@ -201,12 +201,14 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then --avg 1 \ --exp-dir lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --decoding-method modified_beam_search_lm_shallow_fusion \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir $lm_repo/exp \ - --rnn-lm-epoch 88 \ - --rnn-lm-avg 1 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_repo/exp \ + --lm-epoch 88 \ + --lm-avg 1 \ + --lm-scale 0.3 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 fi @@ -245,11 +247,13 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then --avg 1 \ --exp-dir lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_LODR \ + --decoding-method modified_beam_search_LODR \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir $lm_repo/exp \ - --rnn-lm-epoch 88 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_repo/exp \ + --lm-scale 0.4 \ + --lm-epoch 88 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 \ diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index f5ee09e16..3752f67e3 100644 --- 
a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -139,9 +139,10 @@ jobs: cd egs/librispeech/ASR tree lstm_transducer_stateless2/exp cd lstm_transducer_stateless2/exp - echo "===modified_beam_search_rnnlm_shallow_fusion===" - find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + echo "===modified_beam_search_lm_shallow_fusion===" + echo "===Using RNNLM===" + find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Display decoding results for lstm_transducer_stateless2 if: github.event.label.name == 'LODR' @@ -151,8 +152,8 @@ jobs: tree lstm_transducer_stateless2/exp cd lstm_transducer_stateless2/exp echo "===modified_beam_search_rnnlm_LODR===" - find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 092f77814..007d34a62 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -320,6 +320,10 @@ Number of model parameters: 70369391, i.e., 70.37 M |----------------------|------------|-------------|----------------------------------------| | greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 | | modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 39 --avg 6 --max-duration 600 | | fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 | The training commands are: @@ -458,7 +462,9 @@ The WERs are: | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 | | modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 | | modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 | -| modified_beam_search + RNNLM shallow fusion | 2.28 | 5.94 | --iter 468000 --avg 16 | +| modified_beam_search + TransformerLM shallow fusion | 2.37 | 6.48 | --iter 468000 --avg 16 | +| modified_beam_search + RNNLM + LODR | 2.24 | 5.89 | --iter 468000 --avg 16 | +| modified_beam_search + TransformerLM + LODR | 2.19 | 5.90 | --iter 468000 --avg 16 | | fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 | | greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 | | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 | @@ 
-513,9 +519,12 @@ for m in greedy_search fast_beam_search modified_beam_search; do done ``` -To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM -can be found here: +You may also decode using shallow fusion with external neural network LM. To do so you need to +download a well-trained NN LM: +RNN LM: +Transformer LM: +```bash for iter in 472000; do for avg in 8 10 12 14 16 18; do ./lstm_transducer_stateless2/decode.py \ @@ -523,23 +532,24 @@ for iter in 472000; do --avg $avg \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ - --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ done done +``` -You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in . +You may also decode using LODR + LM shallow fusion. This decoding method is proposed in . It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be generated by `generate-lm.sh`, or you may download it from . The decoding command is as follows: +```bash for iter in 472000; do for avg in 8 10 12 14 16 18; do ./lstm_transducer_stateless2/decode.py \ @@ -547,18 +557,22 @@ for iter in 472000; do --avg $avg \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_LODR \ + --decoding-method modified_beam_search_LODR \ --beam 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 \ - --token-ngram 2 \ + --max-contexts 4 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \ + --lm-epoch 99 \ + --lm-scale 0.4 \ + --lm-avg 1 \ + --tokens-ngram 2 \ --ngram-lm-scale -0.16 done done +``` +Note that you can also set `--lm-type transformer` to use transformer LM during LODR. But it will be slower +because it has not been optimized. The pre-trained transformer LM is available at Pretrained models, training logs, decoding logs, and decoding results are available at @@ -1717,6 +1731,9 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di | greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + RNNLM + LODR | 2.23 | 5.17 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 2.27 | 5.26 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 2.22 | 5.11 | --epoch 30 --avg 10 --max-duration 600 | | fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | ```bash @@ -2080,7 +2097,8 @@ subset so that the gigaspeech dataloader never exhausts. 
| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 | | modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 | | modified beam search + rnnlm shallow fusion | 1.94 | 4.2 | --iter 1224000 --avg 14 --max-duration 600 | -| modified beam search + LODR | 1.83 | 4.03 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + rnnlm + LODR | 1.77 | 3.99 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.75 | 3.94 | --iter 1224000 --avg 14 --max-duration 600 | | fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 | The training commands are: @@ -2126,8 +2144,10 @@ for iter in 1224000; do done done ``` -You may also decode using shallow fusion with external RNNLM. To do so you need to -download a well-trained RNNLM from this link +You may also decode using shallow fusion with external neural network LM. To do so you need to +download a well-trained NN LM: +RNN LM: +Transformer LM: ```bash rnn_lm_scale=0.3 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index fa5bf1825..78be9c01f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -93,36 +93,37 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search (with RNNLM shallow fusion) +(8) modified beam search (with LM shallow fusion) ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --decoding-method modified_beam_search_lm_shallow_fusion \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 -(9) modified beam search with RNNLM shallow fusion + LODR +(9) modified beam search with LM shallow fusion + LODR ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ --avg 15 \ --max-duration 600 \ --exp-dir ./lstm_transducer_stateless2/exp \ - --decoding-method modified_beam_search_rnnlm_LODR \ + --decoding-method modified_beam_search_LODR \ --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 \ + --rnn-lm-tie-weights 1 --tokens-ngram 2 \ --ngram-lm-scale -0.16 \ """ @@ -148,14 +149,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, modified_beam_search_ngram_rescoring, - modified_beam_search_rnnlm_LODR, - modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -163,7 +164,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -253,8 +253,8 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - 
modified_beam_search_ngram_rescoring - - modified_beam_search_rnnlm_shallow_fusion - - modified_beam_search_rnnlm_LODR + - modified_beam_search_lm_shallow_fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -344,67 +344,28 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. """, ) @@ -440,8 +401,7 @@ def decode_one_batch( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -470,6 +430,9 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. Returns: Return the decoding result. See above description for the format of the returned dict. 
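The hunks below swap the per-method RNNLM arguments for the new `LmScorer` wrapper that this patch adds in `icefall/lm_wrapper.py` and exposes via `from icefall import LmScorer`. A minimal sketch of how decode.py drives that wrapper, assuming only the interface visible elsewhere in this patch (a constructor taking `lm_type`, `params`, `device`, `lm_scale`, and a `score_token(tokens, lengths, state=None)` method returning `(log_probs, new_state)`); the helper names are illustrative, not part of the code:

```python
import torch

from icefall import LmScorer  # exported by this patch via icefall/__init__.py


def build_lm_scorer(params, device: torch.device) -> LmScorer:
    """Load the shallow-fusion LM once, as main() does in the hunks below."""
    lm = LmScorer(
        lm_type=params.lm_type,   # "rnn" or "transformer"
        params=params,            # carries lm_exp_dir, lm_epoch, lm_avg, ...
        device=device,
        lm_scale=params.lm_scale,
    )
    lm.to(device)
    lm.eval()
    return lm


def score_sos(lm: LmScorer, sos_id: int, device: torch.device):
    """Score the <sos> token to obtain the initial LM score and state,
    mirroring the beam-search code later in this patch."""
    sos = torch.tensor([[sos_id]], dtype=torch.int64, device=device)
    lens = torch.tensor([1], device=device)
    init_score, init_state = lm.score_token(sos, lens)
    return init_score, init_state
```

Keeping the LM behind one wrapper is what lets the same decode.py serve both `--lm-type rnn` and `--lm-type transformer` without further branching.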
@@ -581,20 +544,19 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_LODR": - hyp_tokens = modified_beam_search_rnnlm_LODR( + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -602,8 +564,7 @@ def decode_one_batch( sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -658,8 +619,7 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -678,6 +638,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. 
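Both `modified_beam_search_lm_shallow_fusion` and `modified_beam_search_LODR` (added further down in `beam_search.py`) update a hypothesis score with the same combination rule. The sketch below restates that rule with illustrative names, assuming the scales passed via `--lm-scale` and `--ngram-lm-scale`:

```python
def combined_token_score(
    transducer_logprob: float,  # log p(token) from the joiner's log_softmax
    nnlm_logprob: float,        # log p(token) from the RNN/transformer LM
    lodr_logprob: float,        # log p(token) from the low-order n-gram
    lm_scale: float,            # --lm-scale, e.g. 0.4
    lodr_scale: float,          # --ngram-lm-scale, negative, e.g. -0.16
) -> float:
    # Plain shallow fusion is the first two terms; LODR additionally adds the
    # n-gram term with a negative scale, i.e. it subtracts an approximation of
    # the transducer's internal LM.
    return transducer_logprob + lm_scale * nnlm_logprob + lodr_scale * lodr_logprob
```

With `lodr_scale = 0` this reduces to plain shallow fusion; the LODR recipes in this patch use a negative `--ngram-lm-scale` such as -0.16.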
@@ -711,8 +673,7 @@ def decode_dataset( batch=batch, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -730,6 +691,7 @@ def decode_dataset( batch_str = f"{batch_idx}/{num_batches}" logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -781,6 +743,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -795,9 +758,9 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_rnnlm_LODR", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", "modified_beam_search_ngram_rescoring", - "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -820,12 +783,18 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if "rnnlm" in params.decoding_method: - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" - if "LODR" in params.decoding_method: - params.suffix += "-LODR" + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -954,28 +923,19 @@ def main(): ngram_lm = None ngram_lm_scale = None - # only load rnnlm if used - if "rnnlm" in params.decoding_method: - rnn_lm_scale = params.rnn_lm_scale - - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - assert params.rnn_lm_avg == 1 + LM.to(device) + LM.eval() - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() else: - rnn_lm_model = None - rnn_lm_scale = 0.0 + LM = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -1003,7 +963,9 @@ def main(): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) test_clean_cuts = librispeech.test_clean_cuts() + # test_clean_cuts = test_clean_cuts.subset(first=500) test_other_cuts = librispeech.test_other_cuts() + # test_other_cuts = test_other_cuts.subset(first=500) test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) @@ -1021,8 +983,7 @@ def main(): decoding_graph=decoding_graph, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + LM=LM, ) save_results( diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b324cc9b7..7388af389 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -26,7 +26,9 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM from icefall.utils import ( DecodingResults, add_eos, @@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring( return ans -def modified_beam_search_rnnlm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - rnnlm.clean_cache() - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - 
decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - - # forward RNNLM to get new states and scores - if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps 
= [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search_rnnlm_LODR( +def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, sp: spm.SentencePieceProcessor, LODR_lm: NgramLm, LODR_lm_scale: float, - rnnlm: RnnLmModel, - rnnlm_scale: float, + LM: LmScorer, beam: int = 4, ) -> List[List[int]]: """This function implements LODR (https://arxiv.org/abs/2203.16776) with @@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR( sp: Sentence piece generator. LODR_lm: - A low order n-gram LM + A low order n-gram LM, whose score will be subtracted during shallow fusion LODR_lm_scale: The scale of the LODR_lm - rnnlm (RnnLmModel): - RNNLM, the external language model - rnnlm_scale (float): - scale of RNNLM in shallow fusion + LM: + A neural net LM, e.g an RNNLM or transformer LM beam (int, optional): Beam size. Defaults to 4. @@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR( """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size + assert LM is not None + lm_scale = LM.lm_scale packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, @@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR( # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) B = [HypothesisList() for _ in range(N)] for i in range(N): @@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the RNNLM + state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), state_cost=NgramLmStateCost( LODR_lm @@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR( ) ) - rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 @@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR( It is a little confusing here because this for-loop looks very similar to the one below. Here, we go through all top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that + LM will score those tokens given the LM states. Note that the variable `scores` is the LM score after seeing the new non-blank token. 
""" @@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) - # forward RNNLM to get new states and scores + # forward NN LM to get new states and scores if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) count = 0 # index, used to locate score and lm states for i in range(batch_size): @@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR( state_cost.lm_score, hyp.state_cost.lm_score, ) - # score = score + RNNLM_score - LODR_score - # LODR_LM_scale is a negative number here + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here hyp_log_prob += ( lm_score[new_token] * lm_scale + LODR_lm_scale * current_ngram_score ) # add the lm score lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 else: state_cost = hyp.state_cost @@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR( ans.append(sorted_ans[unsorted_indices[i]]) return ans + + +def modified_beam_search_lm_shallow_fusion( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool = False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. 
+ """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + new_timestamp.append(t) + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + timestamp=new_timestamp, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index e00aab34a..109a94a69 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -92,36 +92,37 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search (with RNNLM shallow fusion) +(8) modified beam search (with LM shallow fusion) 
./pruned_transducer_stateless3/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ - --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 -(9) modified beam search with RNNLM shallow fusion + LODR +(9) modified beam search with LM shallow fusion + LODR ./pruned_transducer_stateless3/decode.py \ --epoch 28 \ --avg 15 \ --max-duration 600 \ --exp-dir ./pruned_transducer_stateless3/exp \ - --decoding-method modified_beam_search_rnnlm_LODR \ - --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 \ + --rnn-lm-tie-weights 1 --tokens-ngram 2 \ --ngram-lm-scale -0.16 \ """ @@ -149,14 +150,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, modified_beam_search_ngram_rescoring, - modified_beam_search_rnnlm_LODR, - modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm +from icefall import LmScorer, NgramLm from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel @@ -240,8 +241,8 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified_beam_search_rnnlm_shallow_fusion - - modified_beam_search_rnnlm_LODR + - modified_beam_search_lm_shallow_fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -392,58 +393,28 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is rnn-lm. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is rnn-lm. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer + default=False, + help="""Use neural network LM for shallow fusion. 
+ If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. """, ) @@ -481,7 +452,7 @@ def decode_one_batch( ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, rnn_lm_model: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -515,10 +486,9 @@ def decode_one_batch( fast_beam_search_nbest, fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. It an FsaVec containing an acceptor. - rnn_lm_model: - A rnnlm which can be used for rescoring or shallow fusion - rnnlm_scale: - The scale of the rnnlm. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. ngram_lm: A ngram lm. Used in LODR decoding. ngram_lm_scale: @@ -697,20 +667,19 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, sp=sp, - rnnlm=rnn_lm_model, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_LODR": - hyp_tokens = modified_beam_search_rnnlm_LODR( + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -718,8 +687,7 @@ def decode_one_batch( sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, - rnnlm=rnn_lm_model, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -812,7 +780,7 @@ def decode_dataset( ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, rnn_lm_model: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -836,6 +804,8 @@ def decode_dataset( fast_beam_search_nbest, fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. It's an FsaVec containing an acceptor. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. 
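The two LM types behind `LmScorer` need different inputs when the beam search re-scores the top-k non-blank candidates: the RNN branch passes only the new token together with the cached `(h, c)` states, while the transformer branch re-scores the full history ending in the candidate token. A sketch of the transformer-side batching used in the new beam-search code, with an illustrative helper name:

```python
import torch


def batch_histories_for_transformer_lm(token_list, device):
    """token_list[i] is [sos_id, y_1, ..., y_k, new_token] for one hypothesis."""
    x_lens = torch.tensor([len(t) for t in token_list], device=device)
    tokens_to_score = (
        torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(t) for t in token_list],
            batch_first=True,
            padding_value=0.0,
        )
        .to(device)
        .to(torch.int64)
    )
    # Passed to LM.score_token(tokens_to_score, x_lens) with state=None,
    # since the transformer LM keeps no recurrent state between calls.
    return tokens_to_score, x_lens
```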
@@ -871,7 +841,7 @@ def decode_dataset( ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -1005,6 +975,7 @@ def load_ngram_LM( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -1022,9 +993,9 @@ def main(): "modified_beam_search", "fast_beam_search_with_nbest_rescoring", "fast_beam_search_with_nbest_rnn_rescoring", - "modified_beam_search_rnnlm_LODR", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", "modified_beam_search_ngram_rescoring", - "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -1055,12 +1026,18 @@ def main(): params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-temperature-{params.temperature}" - if "rnnlm" in params.decoding_method: - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - if "LODR" in params.decoding_method: - params.suffix += "-LODR" if "ngram" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1195,28 +1172,19 @@ def main(): ngram_lm = None ngram_lm_scale = None - # only load rnnlm if used - if "rnnlm" in params.decoding_method: - rnn_lm_scale = params.rnn_lm_scale - - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - assert params.rnn_lm_avg == 1 + LM.to(device) + LM.eval() - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() else: - rnn_lm_model = None - rnn_lm_scale = 0.0 + LM = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -1247,7 +1215,7 @@ def main(): ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8b993f638..90b0fcf4b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -87,22 +87,39 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search with RNNLM shallow fusion (with LG) +(8) modified beam search with RNNLM shallow fusion ./pruned_transducer_stateless5/decode.py \ --epoch 35 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir 
/path/to/RNNLM/exp \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ """ @@ -128,10 +145,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, - modified_beam_search_rnnlm_shallow_fusion, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -139,7 +159,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -229,7 +248,8 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -342,69 +362,49 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer + help="""Use neural network LM for shallow fusion. 
+ If you want to use LODR, you will also need to set this to true """, ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -417,8 +417,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -447,6 +448,13 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -559,15 +567,38 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -620,8 +651,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
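The hunks below also make `main()` load the low-order n-gram only when LODR or n-gram rescoring is requested. Pulled out as a sketch (helper name illustrative), following the patch's `<lang_dir>/<tokens_ngram>gram.fst.txt` naming and the `NgramLm` constructor it uses:

```python
from icefall import NgramLm


def maybe_load_lodr_ngram(params):
    """Load the bi-/tri-gram used to approximate the internal LM for LODR."""
    if "ngram" not in params.decoding_method and "LODR" not in params.decoding_method:
        return None, None
    lm_filename = f"{params.tokens_ngram}gram.fst.txt"  # e.g. 2gram.fst.txt for LODR
    ngram_lm = NgramLm(
        str(params.lang_dir / lm_filename),
        backoff_id=params.backoff_id,  # --backoff-id, default 500
        is_binary=False,
    )
    return ngram_lm, params.ngram_lm_scale  # negative for LODR, e.g. -0.16
```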
@@ -640,6 +672,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -663,7 +697,6 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - logging.info(f"Decoding {batch_idx}-th batch") hyps_dict = decode_one_batch( params=params, @@ -672,8 +705,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -742,6 +776,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -757,7 +792,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_rnnlm_shallow_fusion", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -783,7 +819,18 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -895,24 +942,34 @@ def main(): model.to(device) model.eval() - rnn_lm_model = None - rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, ) - assert params.rnn_lm_avg == 1 + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - rnn_lm_model.to(device) - rnn_lm_model.eval() + LM.to(device) + LM.eval() + + else: + LM = None if 
"fast_beam_search" in params.decoding_method: if "LG" in params.decoding_method: @@ -955,8 +1012,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index bc15948fc..b9bce465f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,6 +92,41 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ + """ @@ -115,9 +151,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -213,6 +253,8 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -274,6 +316,7 @@ def get_parser(): default=2, help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -323,6 +366,50 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. 
+ Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -335,6 +422,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -363,6 +453,13 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -468,6 +565,30 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -517,6 +638,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -535,6 +659,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. 
@@ -566,6 +692,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -634,6 +763,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -648,6 +778,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -675,6 +807,19 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -785,6 +930,34 @@ def main(): model.to(device) model.eval() + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -826,6 +999,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/icefall/__init__.py b/icefall/__init__.py index 27ad74213..82d21706c 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -68,3 +68,5 @@ from .utils import ( ) from .ngram_lm import NgramLm, NgramLmStateCost + +from .lm_wrapper import LmScorer diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py new file mode 100644 index 000000000..0468befd0 --- /dev/null +++ b/icefall/lm_wrapper.py @@ -0,0 +1,254 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +import torch + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import AttributeDict, str2bool + + +class LmScorer(torch.nn.Module): + """This is a wrapper for NN LMs + The language models supported include: + RNN, + Transformer + """ + + def __init__( + self, + lm_type: str, + params: AttributeDict, + device, + lm_scale: float = 0.3, + ): + super(LmScorer, self).__init__() + assert lm_type in ["rnn", "transformer"], f"{lm_type} is not supported" + self.lm_type = lm_type + self.lm = self.get_lm(lm_type, device, params) + self.lm_scale = lm_scale + self.params = params + + @classmethod + def add_arguments(cls, parser): + # LM general arguments + parser.add_argument( + "--vocab-size", + type=int, + default=500, + ) + + parser.add_argument( + "--lm-epoch", + type=int, + default=7, + help="""Which epoch to be used + """, + ) + + parser.add_argument( + "--lm-avg", + type=int, + default=1, + help="""Number of checkpoints to be averaged + """, + ) + + parser.add_argument("--lm-exp-dir", type=str, help="Path to LM experiments") + + # Now RNNLM related arguments + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + # Now transformers + parser.add_argument( + "--transformer-lm-exp-dir", type=str, help="Directory of transformer LM exp" + ) + + parser.add_argument( + "--transformer-lm-dim-feedforward", + type=int, + default=2048, + help="Dimension of FFW module in transformer", + ) + + parser.add_argument( + "--transformer-lm-encoder-dim", + type=int, + default=768, + help="Encoder dimension of transformer", + ) + + parser.add_argument( + "--transformer-lm-embedding-dim", + type=int, + default=768, + help="Input embedding dimension of transformer", + ) + + parser.add_argument( + "--transformer-lm-nhead", + type=int, + default=8, + help="Number of attention heads in transformer", + ) + + parser.add_argument( + "--transformer-lm-num-layers", + type=int, + default=16, + help="Number of encoder layers in transformer", + ) + + parser.add_argument( + "--transformer-lm-tie-weights", + type=str2bool, + default=True, + help="If tie weights in transformer LM", + ) + + def get_lm(self, lm_type: str, device, params: AttributeDict) -> torch.nn.Module: + """Return the neural network LM + + Args: + lm_type (str): Type name of NN LM + """ + if lm_type == "rnn": + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + + if params.lm_avg == 1: + load_checkpoint( + f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model + ) + model.to(device) + else: + start = params.lm_epoch - params.lm_avg + 1 + filenames = [] + for i in range(start, params.lm_epoch + 1): + if start >= 0: + 
filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + elif lm_type == "transformer": + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.transformer_lm_encoder_dim, + embedding_dim=params.transformer_lm_embedding_dim, + dim_feedforward=params.transformer_lm_dim_feedforward, + nhead=params.transformer_lm_nhead, + num_layers=params.transformer_lm_num_layers, + tie_weights=params.transformer_lm_tie_weights, + params=params, + ) + + if params.lm_avg == 1: + load_checkpoint( + f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model + ) + model.to(device) + else: + start = params.lm_epoch - params.lm_avg + 1 + filenames = [] + for i in range(start, params.lm_epoch + 1): + if start >= 0: + filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + raise NotImplementedError() + + return model + + def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): + """Score the input and return the prediction + This requires the lm to have the method `score_token` + Args: + x (torch.Tensor): Input tokens + x_lens (torch.Tensor): Length of the input tokens + state (optional): LM states + + """ + return self.lm.score_token(x, x_lens, state) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + LmScorer.add_arguments(parser) + args = parser.parse_args() + + params = AttributeDict() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + Scorer = LmScorer(params=params, device=device) + Scorer.eval() + + x = ( + torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 256, 77]]) + .to(device) + .to(torch.int64) + ) + x_lens = torch.tensor([5, 5]).to(device) + + state = None + + score, state = Scorer.score(x, x_lens) + print(score.shape) + print(score[0]) + print(score[1]) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 3598a4857..08eb753b5 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -153,9 +153,24 @@ class RnnLmModel(torch.nn.Module): def clean_cache(self): self.cache = {} - def score_token(self, tokens: torch.Tensor, state=None): + def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): + """Score a batch of tokens + + Args: + x (torch.Tensor): + A batch of tokens + x_lens (torch.Tensor): + The length of tokens in the batch before padding + state (_type_, optional): + Either None or a tuple of two torch.Tensor. 
Each tensor has + the shape of (hidden_dim) + + + Returns: + _type_: _description_ + """ device = next(self.parameters()).device - batch_size = tokens.size(0) + batch_size = x.size(0) if state: h, c = state else: @@ -166,7 +181,7 @@ class RnnLmModel(torch.nn.Module): device ) - embedding = self.input_embedding(tokens) + embedding = self.input_embedding(x) rnn_out, states = self.rnn(embedding, (h, c)) logits = self.output_linear(rnn_out) diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 803da99d6..f43e66cd2 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -531,6 +531,9 @@ def run(rank, world_size, args): tie_weights=params.tie_weights, ) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + checkpoints = load_checkpoint_if_available(params=params, model=model) model.to(device) diff --git a/icefall/transformer_lm/attention.py b/icefall/transformer_lm/attention.py new file mode 100644 index 000000000..5ce83b15e --- /dev/null +++ b/icefall/transformer_lm/attention.py @@ -0,0 +1,510 @@ +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.transformer_lm.scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) +from icefall.utils import is_jit_tracing + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + # linear transformation for positional encoding. 
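+        # Note: linear_pos below projects the relative positional embeddings
+        # into the same per-head space as the keys; its output is combined
+        # with the (query + pos_bias_v) term to form "matrix_bd" inside
+        # multi_head_attention_forward().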
+ self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. 
+ + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + left_context=left_context, + ) + + def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1+left_context). + time1 means the length of query vector. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + + time2 = time1 + left_context + if not is_jit_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
+ left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + + tgt_len, bsz, embed_dim = query.size() + if not is_jit_tracing(): + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + if not is_jit_tracing(): + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None and not is_jit_tracing(): + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + if not is_jit_tracing(): + assert pos_emb_bsz in (1, bsz) # actually it is 1 + + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd, left_context) + + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + + if not is_jit_tracing(): + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + + # If we are using dynamic_chunk_training and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`, at this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking + # positions to avoid invalid loss value below. 
+ if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + + if not is_jit_tracing(): + assert list(attn_output.size()) == [ + bsz * num_heads, + tgt_len, + head_dim, + ] + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None diff --git a/icefall/transformer_lm/compute_perplexity.py b/icefall/transformer_lm/compute_perplexity.py new file mode 100644 index 000000000..72d7c477b --- /dev/null +++ b/icefall/transformer_lm/compute_perplexity.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +from pathlib import Path + +import torch +from dataset import get_dataloader +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=7, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transformer_lm/exp_full_libri_16layer_maxlen200_8gpu", + ) + + parser.add_argument( + "--lm-data", + type=str, + help="Path to the LM test data for computing perplexity", + default="transformer_lm/libri_lm_training_bpe500/sorted_lm_data-test.pt", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=16, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=100, + help="Number of RNN layers the model", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lm_data = Path(args.lm_data) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ppl/") + logging.info("Computing perplexity started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.encoder_dim, + embedding_dim=params.embedding_dim, + dim_feedforward=params.dim_feedforward, + nhead=params.nhead, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + params=params, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + num_param_requires_grad = sum( + [p.numel() for p in model.parameters() if p.requires_grad] + ) + + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Number of model parameters (requires_grad): " + f"{num_param_requires_grad} " + f"({num_param_requires_grad/num_param_requires_grad*100}%)" + ) + + logging.info(f"Loading LM test data from {params.lm_data}") + test_dl = get_dataloader( + filename=params.lm_data, + is_distributed=False, + params=params, + ) + + tot_loss = 0.0 + num_tokens = 0 + num_sentences = 0 + for batch_idx, batch in enumerate(test_dl): + x, y, sentence_lengths = batch + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum().cpu().item() + + tot_loss += loss + num_tokens += sentence_lengths.sum().cpu().item() + num_sentences += x.size(0) + + ppl = math.exp(tot_loss / num_tokens) + logging.info( + f"total nll: {tot_loss}, num tokens: {num_tokens}, " + f"num sentences: {num_sentences}, ppl: {ppl:.3f}" + ) + + +if __name__ == "__main__": + main() diff --git a/icefall/transformer_lm/dataset.py b/icefall/transformer_lm/dataset.py new file mode 120000 index 000000000..5792a6cf0 --- /dev/null +++ b/icefall/transformer_lm/dataset.py @@ -0,0 +1 @@ 
+../rnn_lm/dataset.py \ No newline at end of file diff --git a/icefall/transformer_lm/encoder.py b/icefall/transformer_lm/encoder.py new file mode 100644 index 000000000..4357b83d7 --- /dev/null +++ b/icefall/transformer_lm/encoder.py @@ -0,0 +1,329 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from icefall.transformer_lm.attention import RelPositionMultiheadAttention +from icefall.transformer_lm.scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) +from icefall.utils import is_jit_tracing, make_pad_mask + + +class Transformer(torch.nn.Module): + """_summary_ + + Args: + input_dim (int): Input feature dimension + d_mode (int): The dimension of the transformer + dim_feedforward (int ): The dimension of the ffw module + nhead (int): The number of attention heads + dropout_rate (float): dropout rate + att_dropout (float): dropout rate in attention module + """ + + def __init__( + self, + input_dim: int, + d_model: int, + dim_feedforward: int, + nhead: int = 4, + num_layers: int = 6, + dropout_rate: float = 0.1, + att_dropout: float = 0.0, + ): + super().__init__() + + self.encoder_layers = num_layers + self.d_model = d_model + + self.embed = ScaledLinear(input_dim, d_model) + self.norm_before = BasicNorm(d_model, learn_eps=False) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + nhead=nhead, + dropout_rate=dropout_rate, + ) + + self.encoder = TransformerEncoder(encoder_layer, num_layers) + + def _create_attention_mask(self, x_lens: torch.Tensor): + # create a 2D attention mask to mask out + # the upper right half of the attention matrix + max_len = max(x_lens) + ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool) + return torch.triu(ones, diagonal=1) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Transformer forward + + Args: + x (torch.Tensor): Input tensor (B,T,input_dim) + x_lens (torch.Tensor): The length of input tensors before padding (B,) + + Returns: + Return a tuple of 2 tensors: + - x: output feature of the transformer (B,T,d_model) + - x_lens: output feature lens of the transformer + """ + + attention_mask = self._create_attention_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + + x = self.norm_before(self.embed(x)) + + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) + + x = self.encoder( + x, + pos_emb, + mask=attention_mask, # pass the attention mast + src_key_padding_mask=src_key_padding_mask, + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x, x_lens + + +class 
TransformerEncoder(torch.nn.Module): + def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None: + """TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer (torch.nn.Module): an instance of the TransformerEncoderLayer() + num_layers (int): Number of layers to be stacked + """ + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """_summary_ + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + output: transformer encoded features + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_key_padding_mask=src_key_padding_mask, + src_mask=mask, + ) + + return output + + +class TransformerEncoderLayer(torch.nn.Module): + def __init__( + self, + d_model: int, + dim_feedforward: int, + nhead: int, + dropout_rate: float, + ): + """TransformerEncoderLayer is made up of self-attn and feedforward module + + Args: + d_model (int): The model size + dim_feedforward (int): Dimension of ffw module + nhead (int): Number of heads + dropout_rate (float): Dropout rate + """ + super().__init__() + + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout_rate), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + src_mask: Optional[torch.Tensor] = None, + cache=None, + ): + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_key_padding_mask: the mask for the src keys per batch (optional). + src_mask: the mask for the src sequence (optional). + """ + src_orig = src + + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + + src = src + self.dropout(src_att) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + return src + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. 
+ + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 + + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: + """Reset the positional encodings.""" + x_size_1 = x.size(1) + left_context + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_1 * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x, left_context) + x_size_1 = x.size(1) + left_context + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_1 + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) diff --git a/icefall/transformer_lm/export.py b/icefall/transformer_lm/export.py new file mode 100644 index 000000000..c08982e37 --- /dev/null +++ b/icefall/transformer_lm/export.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from model import TransformerLM + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, load_averaged_model, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=11, + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=768, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=768, + help="Encoder dim of the model", + ) + + parser.add_argument( + "--dim_feedforward", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=16, + help="Number of Transformer layers", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = AttributeDict({}) + params.update(vars(args)) + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("About to create model") + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.encoder_dim, + embedding_dim=params.embedding_dim, + dim_feedforward=params.dim_feedforward, + nhead=params.nhead, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + params=params, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py new file mode 100644 index 000000000..79dda3168 --- /dev/null +++ b/icefall/transformer_lm/model.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from icefall.transformer_lm.encoder import Transformer +from icefall.utils import AttributeDict, add_eos, add_sos, make_pad_mask + + +class TransformerLM(torch.nn.Module): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + d_model: int, + dim_feedforward: int, + nhead: int = 8, + num_layers: int = 16, + tie_weights: bool = True, + dropout: float = 0.1, + emb_dropout_rate: float = 0.0, + params: AttributeDict = None, + ): + super().__init__() + + self.vocab_size = vocab_size + self.params = params + + self.input_embedding = torch.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + ) + + self.encoder = Transformer( + input_dim=embedding_dim, + d_model=d_model, + dim_feedforward=dim_feedforward, + nhead=nhead, + num_layers=num_layers, + dropout_rate=dropout, + ) + + self.output_linear = torch.nn.Linear( + in_features=d_model, out_features=vocab_size + ) + if tie_weights: + logging.info("Tying weights") + assert d_model == embedding_dim, (d_model, embedding_dim) + self.output_linear.weight = self.input_embedding.weight + else: + logging.info("Not tying weights") + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + x_lens: torch.Tensor, + return_logits: bool = False, + ): + """Forward transformer language model + + Args: + x (torch.Tensor): Input tokens (B,L) + y (torch.Tensor): Output tokens (with EOS appended) (B,L) + x_lens (torch.Tensor): Length of input tokens before padding (B,) + return_logits (bool, optional): Return logits instead of NLL + + """ + + x = self.input_embedding(x) + + x, x_lens = self.encoder(x, x_lens) + + logits = self.output_linear(x) + + if return_logits: + return logits + + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + + mask = make_pad_mask(x_lens).reshape(-1) + nll_loss.masked_fill_(mask, 0) + + return nll_loss + + def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): + + bs = x.size(0) + + state = None + logits = self.forward(x, x, x_lens, return_logits=True) + index = torch.arange(bs) + + last_logits = logits[index, x_lens - 1, :] + + return last_logits.log_softmax(-1), state diff --git a/icefall/transformer_lm/scaling.py b/icefall/transformer_lm/scaling.py new file mode 120000 index 000000000..0876c0704 --- /dev/null +++ b/icefall/transformer_lm/scaling.py @@ -0,0 +1 @@ +../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py new file mode 100644 index 000000000..c36abfcdf --- /dev/null +++ b/icefall/transformer_lm/train.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Usage: + ./transformer_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 1 \ + --use-fp16 0 \ + --num-layers 12 \ + --batch-size 400 + +""" + +import argparse +import logging +import math +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from dataset import get_dataloader +from lhotse.utils import fix_random_seed +from model import TransformerLM +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + exp_dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transformer_lm/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, logs, etc, are saved + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=400, + ) + + parser.add_argument( + "--lm-data", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data.pt", + help="LM training data", + ) + + parser.add_argument( + "--lm-data-valid", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data-valid.pt", + help="LM validation data", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=12, + help="Number of Transformer layers in the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters.""" + + params = AttributeDict( + { + "max_sent_len": 200, + "sos_id": 1, + "eos_id": 1, + "blank_id": 0, + "lr": 1e-3, + "weight_decay": 1e-6, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 200, + "reset_interval": 2000, + "valid_interval": 1000, + "nhead": 8, + "embedding_dim": 768, + "encoder_dim": 768, + "dim_feedforward": 2048, + "dropout": 0.1, + "env_info": get_env_info(), + } + ) + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + logging.info(f"Loading checkpoint: {filename}") + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + sentence_lengths: torch.Tensor, + is_training: bool, +) -> Tuple[torch.Tensor, MetricsTracker]: + """Compute the negative log-likelihood loss given a model and its input. + Args: + model: + The NN model, + x: + A 2-D tensor. Each row contains BPE token IDs for a sentence. Also, + each row starts with SOS ID. + y: + A 2-D tensor. Each row is a shifted version of the corresponding row + in `x` but ends with an EOS ID (before padding). + sentence_lengths: + A 1-D tensor containing number of tokens of each sentence + before padding. + is_training: + True for training. False for validation. + """ + with torch.set_grad_enabled(is_training): + device = model.device + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum() + + num_tokens = sentence_lengths.sum().item() + + loss_info = MetricsTracker() + # Note: Due to how MetricsTracker() is designed, + # we use "frames" instead of "num_tokens" as a key here + loss_info["frames"] = num_tokens + loss_info["loss"] = loss.detach().item() + return loss, loss_info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + x, y, sentence_lengths = batch + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all sentences is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. 
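+
+    Note:
+      The perplexities reported in the logs and in tensorboard are computed as
+      exp(loss / num_tokens); `tot_loss` is an exponentially decayed running
+      sum controlled by `params.reset_interval`.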
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + x, y, sentence_lengths = batch + batch_size = x.size(0) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + # Note: "frames" here means "num_tokens" + this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) + tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " + f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + tb_writer.add_scalar( + "train/current_ppl", this_batch_ppl, params.batch_idx_train + ) + + tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + + valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) + logging.info( + f"Epoch {params.cur_epoch}, validation: {valid_info}, " + f"ppl: {valid_ppl}" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/valid_ppl", valid_ppl, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + is_distributed = world_size > 1 + + fix_random_seed(params.seed) + if is_distributed: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.encoder_dim, + embedding_dim=params.embedding_dim, + dim_feedforward=params.dim_feedforward, + nhead=params.nhead, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + params=params, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if is_distributed: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + if checkpoints: + logging.info("Load optimizer state_dict from checkpoint") + optimizer.load_state_dict(checkpoints["optimizer"]) + + logging.info(f"Loading LM training data from {params.lm_data}") + train_dl = get_dataloader( + filename=params.lm_data, + is_distributed=is_distributed, + params=params, + ) + + logging.info(f"Loading LM validation data from {params.lm_data_valid}") + valid_dl = get_dataloader( + filename=params.lm_data_valid, + is_distributed=is_distributed, + params=params, + ) + + # Note: No learning rate scheduler is used here + for epoch in range(params.start_epoch, params.num_epochs): + if is_distributed: + train_dl.sampler.set_epoch(epoch) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if is_distributed: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From aa0fe4e4ac4d9bb4a1082709c103e76f70eb8b6f Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 29 Dec 2022 11:54:42 +0800 Subject: [PATCH 091/120] Fix typos in RESULTS.md (#797) --- egs/librispeech/ASR/RESULTS.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 007d34a62..05422562c 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -318,13 +318,13 @@ Number of model parameters: 70369391, i.e., 70.37 M | | test-clean | test-other | comment | 
|----------------------|------------|-------------|----------------------------------------| -| greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 39 --avg 6 --max-duration 600 | -| fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 | +| greedy search | 2.17 | 5.23 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search | 2.15 | 5.20 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 30 --avg 9 --max-duration 600 | +| fast beam search | 2.15 | 5.22 | --epoch 30 --avg 9 --max-duration 600 | The training commands are: ```bash From d167aad4abd5d330da9da1aa006478eb4361cd04 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 30 Dec 2022 10:52:18 +0800 Subject: [PATCH 092/120] Add streaming zipformer (#787) * add streaming zipformer codes * add test_model.py * add export.py, pretrained.py, jit_pretrained.py * add cached_len for pooling module * add jit_trace_export.py and jit_trace_pretrained.py * fix bug in jit.trace * update RESULTS.md * add CI test * minor fix in pruned_transducer_stateless7/zipformer.py * update README.md --- ...nsducer-stateless7-streaming-2022-12-29.sh | 148 + ...speech-2022-12-29-stateless7-streaming.yml | 172 + egs/librispeech/ASR/README.md | 22 +- egs/librispeech/ASR/RESULTS.md | 78 + .../pruned_transducer_stateless7/scaling.py | 6 +- .../pruned_transducer_stateless7/zipformer.py | 2 +- .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 813 +++++ .../decode_stream.py | 151 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 320 ++ .../jit_pretrained.py | 278 ++ .../jit_trace_export.py | 313 ++ .../jit_trace_pretrained.py | 295 ++ .../joiner.py | 1 + .../model.py | 1 + .../optim.py | 1 + .../pretrained.py | 355 ++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 615 ++++ .../test_model.py | 150 + .../train.py | 1264 ++++++++ .../zipformer.py | 2881 +++++++++++++++++ 27 files changed, 7867 insertions(+), 6 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh create mode 100644 .github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 
egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh new file mode 100755 index 000000000..afb0dc05a --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "exp/encoder_jit_trace.pt" +git lfs pull --include "exp/decoder_jit_trace.pt" +git lfs pull --include "exp/joiner_jit_trace.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Export to torchscript model by torch.jit.trace()" +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + 
--exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 + +log "Decode with models exported by torch.jit.trace()" + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_streaming/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_streaming/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + num_decode_stream=200 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "decoding with $method" + + ./pruned_transducer_stateless7_streaming/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp + done + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --decode-chunk-len 32 \ + --num-decode-streams $num_decode_stream + --exp-dir pruned_transducer_stateless7_streaming/exp + done + + rm pruned_transducer_stateless7_streaming/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml new file mode 100644 index 000000000..6dd93946a --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -0,0 +1,172 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-2022-12-29-stateless7-streaming +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_librispeech_2022_12_29_zipformer_streaming: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export 
PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7_streaming/exp + + cd pruned_transducer_stateless7_streaming + echo "results for pruned_transducer_stateless7_streaming" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming greedy search===" + find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming fast_beam_search===" + find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming modified beam search===" + find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29 + path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/ diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index caa23a49f..94cb445a8 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -19,18 +19,36 @@ The following table lists the differences among them. 
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data | -| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training | +| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training + delay penalty | | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| | `pruned_transducer_stateless7_ctc` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head| +| `pruned_transducer_stateless7_ctc_bs` | Zipformer | Embedding + Conv1d | pruned_transducer_stateless7_ctc + blank skip | +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 | | `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | | `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model | | `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model | -| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | +| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | +| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. 
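+
+As an editorial illustration only (this sketch is not code from the repository;
+the names `vocab_size`, `decoder_dim`, `context_size`, and `blank_id` are
+assumptions chosen to mirror the description above), a stateless decoder of the
+"Embedding + Conv1d" kind can be sketched as follows:
+
+```python
+import torch
+
+
+class StatelessDecoderSketch(torch.nn.Module):
+    """Embedding of the last `context_size` tokens followed by a Conv1d,
+    standing in for a recurrent prediction network."""
+
+    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int, blank_id: int = 0):
+        super().__init__()
+        self.embedding = torch.nn.Embedding(vocab_size, decoder_dim, padding_idx=blank_id)
+        # Kernel size equals the token context, so the output has length 1.
+        self.conv = torch.nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size)
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, context_size) IDs of the most recently emitted tokens.
+        emb = self.embedding(y).permute(0, 2, 1)  # (N, decoder_dim, context_size)
+        out = self.conv(emb).permute(0, 2, 1)  # (N, 1, decoder_dim)
+        return torch.relu(out)
+```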
+ +# CTC + +| | Encoder | Comment | +|------------------------------|--------------------|------------------------------| +| `conformer-ctc` | Conformer | Use auxiliary attention head | +| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | +| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | + +# MMI + +| | Encoder | Comment | +|------------------------------|-----------|---------------------------------------------------| +| `conformer-mmi` | Conformer | | +| `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 05422562c..b30cf7c1f 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,83 @@ ## Results +### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) + +#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 70369391, i.e., 70.37 M + +##### training on full librispeech + +The WERs are: + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 3.15 | 8.09 | --epoch 30 --avg 9 | simulated streaming | +| greedy search | 320ms | 3.17 | 8.24 | --epoch 30 --avg 9 | chunk-wise | +| fast beam search | 320ms | 3.2 | 8.04 | --epoch 30 --avg 9 | simulated streaming | +| fast beam search | 320ms | 3.36 | 8.19 | --epoch 30 --avg 9 | chunk-wise | +| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 320ms | 3.12 | 8.11 | --epoch 30 --avg 9 | chunk-size | +| greedy search | 640ms | 2.97 | 7.5 | --epoch 30 --avg 9 | simulated streaming | +| greedy search | 640ms | 2.98 | 7.67 | --epoch 30 --avg 9 | chunk-wise | +| fast beam search | 640ms | 3.02 | 7.47 | --epoch 30 --avg 9 | simulated streaming | +| fast beam search | 640ms | 2.96 | 7.61 | --epoch 30 --avg 9 | chunk-wise | +| modified beam search | 640ms | 2.94 | 7.36 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 640ms | 2.95 | 7.53 | --epoch 30 --avg 9 | chunk-size | + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. 
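+
+As a rough editorial sketch of the difference between the two modes (the helper
+names `get_init_state` and `streaming_forward` are assumptions used for
+illustration; `streaming_decode.py` in this recipe is the actual
+implementation):
+
+```python
+import torch
+
+
+def chunkwise_encode(model, features: torch.Tensor, chunk: int = 32) -> torch.Tensor:
+    """Feed `chunk` frames at a time and carry cached encoder states across
+    chunks (chunk-wise mode). Simulated streaming instead feeds the whole
+    utterance in one call, e.g. model.encoder(x=features, x_lens=...), relying
+    on the causal/chunked attention masks inside the encoder."""
+    states = model.encoder.get_init_state(device=features.device)  # assumed helper
+    outs = []
+    for start in range(0, features.size(1), chunk):
+        x = features[:, start:start + chunk]
+        x_lens = torch.full((x.size(0),), x.size(1), dtype=torch.int64, device=x.device)
+        y, y_lens, states = model.encoder.streaming_forward(x=x, x_lens=x_lens, states=states)
+        outs.append(y)
+    return torch.cat(outs, dim=1)
+```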
+
+The training command is:
+
+```bash
+./pruned_transducer_stateless7_streaming/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7_streaming/exp \
+  --full-libri 1 \
+  --max-duration 750 \
+  --master-port 12345
+```
+
+The tensorboard log can be found at
+
+
+The simulated streaming decoding command (e.g., chunk-size=320ms) is:
+```bash
+for m in greedy_search fast_beam_search modified_beam_search; do
+  ./pruned_transducer_stateless7_streaming/decode.py \
+    --epoch 30 \
+    --avg 9 \
+    --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+    --max-duration 600 \
+    --decode-chunk-len 32 \
+    --decoding-method $m
+done
+```
+
+The streaming chunk-size decoding command (e.g., chunk-size=320ms) is:
+```bash
+for m in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless7_streaming/streaming_decode.py \
+    --epoch 30 \
+    --avg 9 \
+    --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+    --decoding-method $m \
+    --decode-chunk-len 32 \
+    --num-decode-streams 2000
+done
+```
+
+
 ### zipformer_mmi (zipformer with mmi loss)

 See for more details.

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 042c9c3e4..1cbde6db0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -298,7 +298,7 @@ class SoftmaxFunction(torch.autograd.Function):
 
 
 def softmax(x: Tensor, dim: int):
-    if torch.jit.is_scripting():
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
         return x.softmax(dim)
 
     return SoftmaxFunction.apply(x, dim)
@@ -783,7 +783,7 @@ class WithLoss(torch.autograd.Function):
 
 
 def with_loss(x, y):
-    if torch.jit.is_scripting():
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
         return x
     # returns x but adds y.sum() to the loss function.
     return WithLoss.apply(x, y)
@@ -1013,7 +1013,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
""" - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x * torch.sigmoid(x - 1.0) return DoubleSwishFunction.apply(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index ad3b88df0..d18258085 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -907,7 +907,7 @@ class RelPositionalEncoding(torch.nn.Module): self.d_model = d_model self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self.extend_pe(torch.tensor(0.0).expand(max_len)) def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..aebe2b94b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
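+      For example, with greedy search the returned dict may look like
+      {"greedy_search": [["HELLO", "WORLD"], ...]}, i.e., one list of words
+      per utterance in the batch.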
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += 
f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + 
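+            # At this point `filenames` holds the (at most --avg) checkpoint-*.pt
+            # files selected via --iter; they are averaged into a single model below.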
logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 100644 index 000000000..0d7e86fcf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1,151 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, after subsampling (i.e. 
a + # cumulative sum of the second return value of + # encoder.streaming_forward + self.done_frames: int = 0 + + # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2 + # 1) feature embedding: out_lens=(x_lens-7)//2 + # 2) output subsampling: out_lens=(out_lens+1)//2 + self.pad_length = 7 + + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100755 index 000000000..5c06cc052 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ 
-0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter 
{params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 100755 index 000000000..4fd5e1820 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_streaming/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. 
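+      For example (hypothetical token ids), a batch of two utterances might
+      give [[25, 8, 342], [17, 90]]; each inner list is later passed to
+      sp.decode() to obtain words.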
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + model.encoder.decode_chunk_size = args.decode_chunk_len // 2 + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s 
[%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100755 index 000000000..a164f3f69 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +""" +Usage: +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
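+        The traced model saved here can later be loaded with torch.jit.load(),
+        e.g. by ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py.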
+ """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + 
filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100755 index 000000000..f2ac1914d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List, Optional + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - 
num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100755 index 000000000..fb77fdd42 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
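+        For example (hypothetical numbers), with --beam 4 and a best state score
+        of -1.0, states scoring below -5.0 are pruned during the search.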
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..3a5f89833 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..7a349ecb2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
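+    # For example, a final chunk with only 16 feature frames is padded below to
+    # tail_length = 23 frames (with LOG_EPS, i.e. log(1e-10)), so that
+    # (23 - 7) // 2 = 8 embedded frames remain -- one full step of the most
+    # heavily downsampled (factor-8) encoder stack.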
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
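+        # Each DecodeStream holds the cut id, the ground-truth text, the
+        # pre-computed fbank features and a fresh copy of the encoder's initial
+        # state.  Once `num_decode_streams` streams are active, they are decoded
+        # chunk by chunk in parallel (see the while-loop below) and finished
+        # streams are retired so new cuts can take their place.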
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
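+        # write_error_stats() returns the WER for this decoding key; the
+        # per-key WERs are collected in test_set_wers and summarised in the
+        # wer-summary-* file written after this loop.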
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + 
raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..5400df804 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert 
torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..2bdc882a5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1264 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
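+            # So the scale is grown manually here: it is doubled whenever it has
+            # fallen below 1.0 (and every 400 batches while it is still below
+            # 8.0); a persistently small scale usually means the gradients keep
+            # producing inf/nan, so we warn below 0.01 and give up below 1e-05.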
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 100644 index 000000000..88beb38c1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1,2881 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
+) +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + Identity, + MaxEig, + ScaledConv1d, + Whiten, + _diag, + penalize_abs_values_gt, + random_clamp, + softmax, +) +from torch import Tensor, nn + +from icefall.dist import get_rank +from icefall.utils import make_pad_mask, subsequent_chunk_mask + + +def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + ``states[i][0:num_encoders]`` is the cached numbers of past frames. + ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. + """ + batch_size = len(state_list) + assert len(state_list[0]) % 7 == 0, len(state_list[0]) + num_encoders = len(state_list[0]) // 7 + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + # For cached_len + len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] + for i in range(num_encoders): + # len_avg: (num_layers, batch_size) + len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) + cached_len.append(len_avg) + + # For cached_avg + avg_list = [ + state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # avg: (num_layers, batch_size, D) + avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) + cached_avg.append(avg) + + # For cached_key + key_list = [ + state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # key: (num_layers, left_context_size, batch_size, D) + key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) + cached_key.append(key) + + # For cached_val + val_list = [ + state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val: (num_layers, left_context_size, batch_size, D) + val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) + cached_val.append(val) + + # For cached_val2 + val2_list = [ + state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val2: (num_layers, left_context_size, batch_size, D) + val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) + cached_val2.append(val2) + + # For cached_conv1 + conv1_list = [ + state_list[n][5 * num_encoders : 6 * num_encoders] for n in 
range(batch_size) + ] + for i in range(num_encoders): + # conv1: (num_layers, batch_size, D, kernel-1) + conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) + cached_conv1.append(conv1) + + # For cached_conv2 + conv2_list = [ + state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv2: (num_layers, batch_size, D, kernel-1) + conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A list of states. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. 
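# [Illustrative sketch, not part of the original patch] A round trip through
# stack_states/unstack_states with dummy tensors. The shapes follow the
# docstrings above (here: a single encoder stack); the import path is an
# assumption about how this module is used inside the recipe.
import torch
from zipformer import stack_states, unstack_states  # assumed import path

num_layers, d_model, attn_dim, kernel, left_ctx = 2, 384, 192, 31, 64

def one_utt_state():
    # 7 tensors per encoder stack, each with a batch dimension of size 1
    return [
        torch.zeros(num_layers, 1, dtype=torch.int32),       # cached_len
        torch.zeros(num_layers, 1, d_model),                  # cached_avg
        torch.zeros(num_layers, left_ctx, 1, attn_dim),       # cached_key
        torch.zeros(num_layers, left_ctx, 1, attn_dim // 2),  # cached_val
        torch.zeros(num_layers, left_ctx, 1, attn_dim // 2),  # cached_val2
        torch.zeros(num_layers, 1, d_model, kernel - 1),      # cached_conv1
        torch.zeros(num_layers, 1, d_model, kernel - 1),      # cached_conv2
    ]

per_utt = [one_utt_state(), one_utt_state()]
stacked = stack_states(per_utt)      # 7 tensors whose batch dimension is 2
unstacked = unstack_states(stacked)  # back to a list of 2 per-utterance states
assert all(torch.equal(a, b) for a, b in zip(per_utt[0], unstacked[0]))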
+ """ + assert len(states) % 7 == 0, len(states) + num_encoders = len(states) // 7 + ( + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) + + batch_size = cached_len[0].shape[1] + + len_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_len[i]: (num_layers, batch_size) + len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + len_list[n].append(len_avg[n]) + + avg_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_avg[i]: (num_layers, batch_size, D) + avg = cached_avg[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + avg_list[n].append(avg[n]) + + key_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_key[i]: (num_layers, left_context, batch_size, D) + key = cached_key[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + key_list[n].append(key[n]) + + val_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val[i]: (num_layers, left_context, batch_size, D) + val = cached_val[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val_list[n].append(val[n]) + + val2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val2[i]: (num_layers, left_context, batch_size, D) + val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val2_list[n].append(val2[n]) + + conv1_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) + conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv1_list[n].append(conv1[n]) + + conv2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) + conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv2_list[n].append(conv2[n]) + + state_list = [ + ( + len_list[i] + + avg_list[i] + + key_list[i] + + val_list[i] + + val2_list[i] + + conv1_list[i] + + conv2_list[i] + ) + for i in range(batch_size) + ] + return state_list + + +class Zipformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + d_model: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + nhead (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. 
+ warmup_batches (float): number of batches to warm up over + """ + + def __init__( + self, + num_features: int, + output_downsampling_factor: int = 2, + encoder_dims: Tuple[int] = (384, 384), + attention_dim: Tuple[int] = (256, 256), + encoder_unmasked_dims: Tuple[int] = (256, 256), + zipformer_downsampling_factors: Tuple[int] = (2, 4), + nhead: Tuple[int] = (8, 8), + feedforward_dim: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), + dropout: float = 0.1, + cnn_module_kernels: Tuple[int] = (31, 31), + pos_dim: int = 4, + num_left_chunks: int = 4, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 50, + decode_chunk_size: int = 16, + warmup_batches: float = 4000.0, + ) -> None: + super(Zipformer, self).__init__() + + self.num_features = num_features + assert 0 < encoder_dims[0] <= encoder_dims[1] + self.encoder_dims = encoder_dims + self.encoder_unmasked_dims = encoder_unmasked_dims + self.zipformer_downsampling_factors = zipformer_downsampling_factors + self.output_downsampling_factor = output_downsampling_factor + + self.num_left_chunks = num_left_chunks + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + + # Used in decoding + self.decode_chunk_size = decode_chunk_size + + # will be written to, see set_batch_count() + self.batch_count = 0 + self.warmup_end = warmup_batches + + for u, d in zip(encoder_unmasked_dims, encoder_dims): + assert u <= d, (u, d) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7)//2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7)//2 + # (2) embedding: num_features -> encoder_dims + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) + + # each one will be ZipformerEncoder or DownsampledZipformerEncoder + encoders = [] + + self.num_encoders = len(encoder_dims) + for i in range(self.num_encoders): + encoder_layer = ZipformerEncoderLayer( + encoder_dims[i], + attention_dim[i], + nhead[i], + feedforward_dim[i], + dropout, + cnn_module_kernels[i], + pos_dim, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. 
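# [Added note, not part of the original patch] With the defaults
# warmup_batches=4000 and num_encoders=2, the schedule below gives
#   stack 0: warmup_begin = 4000 * 1/3 ~= 1333, warmup_end = 4000 * 2/3 ~= 2667
#   stack 1: warmup_begin = 4000 * 2/3 ~= 2667, warmup_end = 4000 * 3/3 = 4000
# i.e. later stacks start (and finish) warming up later than earlier ones.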
+ encoder = ZipformerEncoder( + encoder_layer, + num_encoder_layers[i], + dropout, + warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), + ) + + if zipformer_downsampling_factors[i] != 1: + encoder = DownsampledZipformerEncoder( + encoder, + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + output_dim=encoder_dims[i], + downsample=zipformer_downsampling_factors[i], + ) + encoders.append(encoder) + self.encoders = nn.ModuleList(encoders) + + # initializes self.skip_layers and self.skip_modules + self._init_skip_modules() + + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) + + def _get_layer_skip_dropout_prob(self): + if not self.training: + return 0.0 + batch_count = self.batch_count + min_dropout_prob = 0.025 + + if batch_count > self.warmup_end: + return min_dropout_prob + else: + return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) + + def _init_skip_modules(self): + """ + If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2, + we combine the outputs of layers 1 and 5. + """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i - 1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i - 2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." + ) + skip_layers.append(j) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks(self, x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. 
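# [Illustrative sketch, not part of the original patch] The effect of one such
# feature mask: on roughly feature_mask_dropout_prob of the frames, every
# channel above the unmasked dimension `u` is zeroed, while the first `u`
# channels always pass through.
import torch

num_frames, batch_size, d, u = 6, 1, 8, 4
frame_mask = (torch.rand(num_frames, batch_size, 1) > 0.15).to(torch.float32)
feature_mask = torch.ones(num_frames, batch_size, d)
feature_mask[:, :, u:] *= frame_mask  # channels >= u follow the frame mask
x = torch.randn(num_frames, batch_size, d) * feature_mask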
+ + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (num_frames, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dims) + if torch.jit.is_scripting() or not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) + + max_downsampling_factor = max(self.zipformer_downsampling_factors) + + num_frames_max = num_frames0 + max_downsampling_factor - 1 + + feature_mask_dropout_prob = 0.15 + + # frame_mask_max shape: (num_frames_max, batch_size, 1) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) + + feature_masks = [] + for i in range(num_encoders): + ds = self.zipformer_downsampling_factors[i] + upsample_factor = max_downsampling_factor // ds + + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) + num_frames = (num_frames0 + ds - 1) // ds + frame_mask = frame_mask[:num_frames] + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) + u = self.encoder_unmasked_dims[i] + feature_mask[:, :, u:] *= frame_mask + feature_masks.append(feature_mask) + + return feature_masks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + chunk_size: + The chunk size used in evaluation mode. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
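# [Illustrative sketch, not part of the original patch] Running the encoder on
# a batch of fbank features; the 80-dim input and the constructor defaults are
# assumptions, and the import path assumes this file's directory is on sys.path.
import torch
from zipformer import Zipformer  # assumed import path

model = Zipformer(num_features=80).eval()
x = torch.randn(2, 200, 80)          # (batch, time, feature)
x_lens = torch.tensor([200, 160])
with torch.no_grad():
    out, out_lens = model(x, x_lens)
# out: (2, 48, 384); out_lens = (((x_lens - 7) >> 1) + 1) >> 1 -> tensor([48, 38])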
+ """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + if self.training: + # Training mode + max_ds = max(self.zipformer_downsampling_factors) + # Generate dynamic chunk-wise attention mask during training + max_len = x.size(0) // max_ds + short_chunk_size = self.short_chunk_size // max_ds + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + # Full attention + chunk_size = x.size(0) + else: + # Chunk-wise attention + chunk_size = chunk_size % short_chunk_size + 1 + chunk_size *= max_ds + else: + chunk_size = self.decode_chunk_size + # Evaluation mode + for ds in self.zipformer_downsampling_factors: + assert chunk_size % ds == 0, (chunk_size, ds) + + attn_mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + attn_mask=attn_mask[::ds, ::ds], + ) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: List[Tensor], + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + seq_len is the input chunk length. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 3 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states. 
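# [Illustrative sketch, not part of the original patch] Chunk-wise streaming
# inference; `model` is assumed to be a Zipformer in eval mode on CPU. With the
# default decode_chunk_size=16 (frames at the 2x-subsampled rate), one chunk
# corresponds to 2 * 16 + 7 = 39 input feature frames.
import torch

states = model.get_init_state(device=torch.device("cpu"))
chunk = torch.randn(1, 2 * model.decode_chunk_size + 7, 80)
chunk_lens = torch.tensor([chunk.size(1)])
with torch.no_grad():
    out, out_lens, states = model.streaming_forward(chunk, chunk_lens, states)
# `states` now carries the cached left context to pass in with the next chunk.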
+ """ + assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) + + cached_len = states[: self.num_encoders] + cached_avg = states[self.num_encoders : 2 * self.num_encoders] + cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] + cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] + cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] + cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] + cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] + + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + + outputs = [] + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + k = self.skip_layers[i] + if isinstance(k, int): + x = skip_module(outputs[k], x) + x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( + x, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + outputs.append(x) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = ( + new_cached_len + + new_cached_avg + + new_cached_key + + new_cached_val + + new_cached_val2 + + new_cached_conv1 + + new_cached_conv2 + ) + return x, lengths, new_states + + @torch.jit.export + def get_init_state( + self, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. 
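# [Added note, not part of the original patch] With the flat layout documented
# above, the entries of encoder stack i can be indexed directly, e.g.
#   states[i]                     -> cached number of past frames
#   states[2 * num_encoders + i]  -> cached keys of the first attention module
#   states[6 * num_encoders + i]  -> cached left context of the second conv module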
+ """ + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + left_context_len = self.decode_chunk_size * self.num_left_chunks + + for i, encoder in enumerate(self.encoders): + num_layers = encoder.num_layers + ds = self.zipformer_downsampling_factors[i] + + len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device) + cached_len.append(len_avg) + + avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) + cached_avg.append(avg) + + key = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim, + device=device, + ) + cached_key.append(key) + + val = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val.append(val) + + val2 = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val2.append(val2) + + conv1 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv1.append(conv1) + + conv2 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + self.attention_dim = attention_dim + self.cnn_module_kernel = cnn_module_kernel + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
+ self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) + + def get_bypass_scale(self): + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + if random.random() < 0.1: + # ensure we get grads if self.bypass_scale becomes out of range + return self.bypass_scale + # hardcode warmup period for bypass scale + warmup_period = 20000.0 + initial_clamp_min = 0.75 + final_clamp_min = 0.25 + if self.batch_count > warmup_period: + clamp_min = final_clamp_min + else: + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) + return self.bypass_scale.clamp(min=clamp_min, max=1.0) + + def get_dynamic_dropout_rate(self): + # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this + # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable + # at the beginning, by making the network focus on the feedforward modules. + if torch.jit.is_scripting() or not self.training: + return 0.0 + warmup_period = 2000.0 + initial_dropout_rate = 0.2 + final_dropout_rate = 0.0 + if self.batch_count > warmup_period: + return final_dropout_rate + else: + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + batch_split: if not None, this layer will only be applied to + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + # dropout rate for submodules that interact with time. 
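# [Added note, not part of the original patch] With the defaults in
# get_dynamic_dropout_rate() (initial 0.2, final 0.0), the product
# initial_dropout_rate * final_dropout_rate is zero, so the returned rate is a
# constant 0.2 for the first 2000 batches and 0.0 afterwards.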
+ dynamic_dropout = self.get_dynamic_dropout_rate() + + # pooling module + if torch.jit.is_scripting(): + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + elif random.random() >= dynamic_dropout: + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + + if torch.jit.is_scripting(): + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + else: + use_self_attn = random.random() >= dynamic_dropout + if use_self_attn: + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + if random.random() >= dynamic_dropout: + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + if use_self_attn: + src = src + self.self_attn.forward2(src, attn_weights) + + if random.random() >= dynamic_dropout: + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.get_bypass_scale() + + return self.whiten(src) + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + cached_len: processed number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor of left context for the first attention module. + cached_val: cached value tensor of left context for the first attention module. + cached_val2: cached value tensor of left context for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + pos_emb: (N, left_context_len+2*S-1, E) + cached_len: (N,) + N is the batch size. + cached_avg: (N, C). + N is the batch size, C is the feature dimension. + cached_key: (left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. 
+ """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + src_pool, cached_len, cached_avg = self.pooling.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + ) + src = src + src_pool + + ( + src_attn, + attn_weights, + cached_key, + cached_val, + ) = self.self_attn.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + cached_val=cached_val, + ) + src = src + src_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + src_attn, cached_val2 = self.self_attn.streaming_forward2( + src, + attn_weights, + cached_val=cached_val2, + ) + src = src + src_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.bypass_scale + + return ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). + self.module_seed = torch.randint(0, 1000, ()).item() + + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.d_model = encoder_layer.d_model + self.attention_dim = encoder_layer.attention_dim + self.cnn_module_kernel = encoder_layer.cnn_module_kernel + + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin + for i in range(num_layers): + self.layers[i].warmup_begin = cur_begin + cur_begin += delta + self.layers[i].warmup_end = cur_begin + + def get_layers_to_drop(self, rnd_seed: int): + ans = set() + if not self.training: + return ans + + batch_count = self.batch_count + num_layers = len(self.layers) + + def get_layerdrop_prob(layer: int) -> float: + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + + initial_layerdrop_prob = 0.5 + final_layerdrop_prob = 0.05 + + if batch_count == 0: + # As a special case, if batch_count == 0, return 0 (drop no + # layers). 
This is rather ugly, I'm afraid; it is intended to + # enable our scan_pessimistic_batches_for_oom() code to work correctly + # so if we are going to get OOM it will happen early. + # also search for 'batch_count' with quotes in this file to see + # how we initialize the warmup count to a random number between + # 0 and 10. + return 0.0 + elif batch_count < layer_warmup_begin: + return initial_layerdrop_prob + elif batch_count > layer_warmup_end: + return final_layerdrop_prob + else: + # linearly interpolate + t = (batch_count - layer_warmup_begin) / layer_warmup_end + assert 0.0 <= t < 1.001, t + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) + + shared_rng = random.Random(batch_count + self.module_seed) + independent_rng = random.Random(rnd_seed) + + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + tot = sum(layerdrop_probs) + # Instead of drawing the samples independently, we first randomly decide + # how many layers to drop out, using the same random number generator between + # jobs so that all jobs drop out the same number (this is for speed). + # Then we use an approximate approach to drop out the individual layers + # with their specified probs while reaching this exact target. + num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) + + layers = list(range(num_layers)) + independent_rng.shuffle(layers) + + # go through the shuffled layers until we get the required number of samples. + if num_to_drop > 0: + for layer in itertools.cycle(layers): + if independent_rng.random() < layerdrop_probs[layer]: + ans.add(layer) + if len(ans) == num_to_drop: + break + if shared_rng.random() < 0.005 or __name__ == "__main__": + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) + return ans + + def forward( + self, + src: Tensor, + # Note: The type of feature_mask should be Union[float, Tensor], + # but to make torch.jit.script() work, we use `float` here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). 
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + pos_emb = self.encoder_pos(src) + output = src + + if torch.jit.is_scripting(): + layers_to_drop = [] + else: + rnd_seed = src.numel() + random.randint(0, 1000) + layers_to_drop = self.get_layers_to_drop(rnd_seed) + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + if not torch.jit.is_scripting(): + if i in layers_to_drop: + continue + output = mod( + output, + pos_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + output = output * feature_mask + + return output + + @torch.jit.export + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + cached_len: number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor for first attention module. + cached_val: cached value tensor for first attention module. + cached_val2: cached value tensor for second attention module. + cached_conv1: cached left contexts for the first convolution module. + cached_conv2: cached left contexts for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + + Returns: A tuple of 8 tensors: + - output tensor + - updated cached number of past frmaes. + - updated cached average of past frmaes. + - updated cached key tensor of of the first attention module. + - updated cached value tensor of of the first attention module. + - updated cached value tensor of of the second attention module. + - updated cached left contexts of the first convolution module. + - updated cached left contexts of the second convolution module. 
+ """ + assert cached_len.size(0) == self.num_layers, ( + cached_len.size(0), + self.num_layers, + ) + assert cached_avg.size(0) == self.num_layers, ( + cached_avg.size(0), + self.num_layers, + ) + assert cached_key.size(0) == self.num_layers, ( + cached_key.size(0), + self.num_layers, + ) + assert cached_val.size(0) == self.num_layers, ( + cached_val.size(0), + self.num_layers, + ) + assert cached_val2.size(0) == self.num_layers, ( + cached_val2.size(0), + self.num_layers, + ) + assert cached_conv1.size(0) == self.num_layers, ( + cached_conv1.size(0), + self.num_layers, + ) + assert cached_conv2.size(0) == self.num_layers, ( + cached_conv2.size(0), + self.num_layers, + ) + + left_context_len = cached_key.shape[1] + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + for i, mod in enumerate(self.layers): + output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( + output, + pos_emb, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + return ( + output, + torch.stack(new_cached_len, dim=0), + torch.stack(new_cached_avg, dim=0), + torch.stack(new_cached_key, dim=0), + torch.stack(new_cached_val, dim=0), + torch.stack(new_cached_val2, dim=0), + torch.stack(new_cached_conv1, dim=0), + torch.stack(new_cached_conv2, dim=0), + ) + + +class DownsampledZipformerEncoder(nn.Module): + r""" + DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample(input_dim, output_dim, downsample) + self.encoder = encoder + self.num_layers = encoder.num_layers + self.d_model = encoder.d_model + self.attention_dim = encoder.attention_dim + self.cnn_module_kernel = encoder.cnn_module_kernel + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + attn_mask: attention mask (optional). Should be downsampled already. + src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. + + Shape: + src: (S, N, E). 
+ attn_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + src = self.encoder( + src, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + cached_avg: cached average value of past frames. + cached_len: length of past frames. + cached_key: cached key tensor for the first attention module. + cached_val: cached value tensor for the first attention module. + cached_val2: cached value tensor for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = self.encoder.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + cached_key=cached_key, + cached_val=cached_val, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return ( + self.out_combiner(src_orig, src), + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, in_channels: int, out_channels: int, downsample: int): + """ + Require out_channels > in_channels. 
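# [Illustrative sketch, not part of the original patch] Shape behaviour of this
# module: frames are grouped into windows of `downsample`, combined by a
# softmax-weighted sum against a learned query, and any extra output channels
# come from a projection of the concatenated window. Import path is assumed.
import torch
from zipformer import AttentionDownsample  # assumed import path

down = AttentionDownsample(in_channels=256, out_channels=384, downsample=2)
y = down(torch.randn(50, 4, 256))  # -> (25, 4, 384)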
+ """ + super(AttentionDownsample, self).__init__() + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + + # fill in the extra dimensions with a projection of the input + if out_channels > in_channels: + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) + else: + self.extra_proj = None + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, 1, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. 
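# [Illustrative sketch, not part of the original patch] Combining a 256-dim and
# a 384-dim stream; src1 is zero-padded up to dim2 before the learned weighted
# sum, so the output keeps src2's dimension. Import path is assumed.
import torch
from zipformer import SimpleCombiner  # assumed import path

comb = SimpleCombiner(dim1=256, dim2=384, min_weight=(0.0, 0.25))
out = comb(torch.randn(10, 4, 256), torch.randn(10, 4, 384))  # -> (10, 4, 384)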
+ """ + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + super(SimpleCombiner, self).__init__() + assert dim2 >= dim1, (dim2, dim1) + self.weight1 = nn.Parameter(torch.zeros(())) + self.min_weight = min_weight + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + """ + src1: (*, dim1) + src2: (*, dim2) + + Returns: a tensor of shape (*, dim2) + """ + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) + + weight1 = self.weight1 + if not torch.jit.is_scripting(): + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) + + src1 = src1 * weight1 + src2 = src2 * (1.0 - weight1) + + src1_dim = src1.shape[-1] + src2_dim = src2.shape[-1] + if src1_dim != src2_dim: + if src1_dim < src2_dim: + src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) + else: + src1 = src1[:src2_dim] + + return src1 + src2 + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + ) -> None: + """Construct a PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + x_size_left = x.size(0) + left_context_len + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_left * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tensor: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). + + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_left + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(0), + ] + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, may be less or more than embed_dim + but must be a multiple of num_heads. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. 
+ + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option. + # they only copy their inputs. + self.copy_pos_query = Identity() + self.copy_query = Identity() + + self.out_proj = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + + self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + # self.whiten_values2 is applied on the values in forward2() + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Returns: (attn_output, attn_weights) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. + - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. + + - Returns: (attn_output, attn_weights, cached_key, cached_val) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. 
+ - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of + left context + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of + """ + ( + x, + weights, + cached_key, + cached_val, + ) = self.streaming_multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.out_proj.weight, + self.out_proj.bias, + cached_key=cached_key, + cached_val=cached_val, + ) + return x, weights, cached_key, cached_val + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. 
+ """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(seq_len, bsz, num_heads, head_dim) + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len + ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. 
+ pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + if not torch.jit.is_scripting(): + if training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) + else: + attn_output_weights = attn_output_weights + attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. 
+ if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights + + def streaming_multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + out_proj_weight, out_proj_bias: the output projection weight and bias. + cached_key: cached attention key tensor of left context. + cached_val: cached attention value tensor of left context. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. 
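+        # Layout of x_proj along its last dim, matching self.in_proj in __init__:
+        # [query: attention_dim | key: attention_dim | value: attention_dim // 2 |
+        #  positional query: num_heads * pos_dim].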
+ p = x_proj[..., 2 * attention_dim + value_dim :] + + left_context_len = cached_key.shape[0] + assert left_context_len > 0, left_context_len + assert cached_key.shape[0] == cached_val.shape[0], ( + cached_key.shape, + cached_val.shape, + ) + # Pad cached left contexts + k = torch.cat([cached_key, k], dim=0) + v = torch.cat([cached_val, v], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + cached_val = v[-left_context_len:, ...] + + # The length of key and value + kv_len = k.shape[0] + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(kv_len, bsz, num_heads, head_dim) + v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + left_context_len + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights, cached_key, cached_val + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. 
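+        # attn_weights has shape (bsz * num_heads, seq_len, seq_len); each head
+        # re-uses the attention pattern it computed in forward() to aggregate
+        # this second value projection of x.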
+ v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + def streaming_forward2( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + cached_val: cached attention value tensor of left context. + Returns: + - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + - updated cached attention value tensor of left context. + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + + left_context_len = cached_val.shape[0] + assert left_context_len > 0, left_context_len + v = torch.cat([cached_val, v], dim=0) + cached_val = v[-left_context_len:] + + seq_len2 = left_context_len + seq_len + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. 
+ return self.out_proj2(attn_output), cached_val + + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) + # attn_covar: (num_heads, head_dim, head_dim) + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) + + +class PoolingModule(nn.Module): + """ + Averages the input over the time dimension and project with a square matrix. + """ + + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: a Tensor of shape (T, N, C) + src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked + positions. + + Returns: + - output, a Tensor of shape (T, N, C). + """ + if src_key_padding_mask is not None: + # False in padding positions + padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) + # Cumulated numbers of frames from start + cum_mask = padding_mask.cumsum(dim=1) # (N, T) + x = x.cumsum(dim=0) # (T, N, C) + pooling_mask = padding_mask / cum_mask + pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + else: + num_frames = x.shape[0] + cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) + x = x.cumsum(dim=0) # (T, N, C) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask + + x = self.proj(x) + return x + + def streaming_forward( + self, + x: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + x: a Tensor of shape (T, N, C) + cached_len: a Tensor of int, of shape (N,), containing the number of + past frames in batch. + cached_avg: a Tensor of shape (N, C), the average over all past frames + in batch. + + Returns: + A tuple of 2 tensors: + - output, a Tensor of shape (T, N, C). + - updated cached_avg, a Tensor of shape (N, C). 
+ """ + x = x.cumsum(dim=0) # (T, N, C) + x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) + # Cumulated numbers of frames from start + cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) + cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + + cached_len = cached_len + x.size(0) + cached_avg = x[-1] + + x = self.proj(x) + return x, cached_len, cached_avg + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. 
+ self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + # Will pad cached left context + self.lorder = kernel_size - 1 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + # 1D Depthwise Conv + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch: + (batch, #time), contains bool in masked positions. + cache: Cached left context for depthwise_conv, with shape of + (batch, channels, #kernel_size-1). Only used in real streaming decoding. + + Returns: + A tuple of 2 tensors: + - Output tensor (#time, batch, channels). + - New cached left context, with shape of (batch, channels, #kernel_size-1). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( + cache.shape, + (x.size(0), x.size(1), self.lorder), + ) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[:, :, -self.lorder :] + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1), cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ActivationBalancer(layer1_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + ActivationBalancer(layer2_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + ActivationBalancer(layer3_channels, channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, (T-7)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + # Now x is of shape (N, (T-7)//2, odim) + x = self.dropout(x) + return x + + +def _test_zipformer_main(): + feature_dim = 50 + batch_size = 5 + seq_len = 47 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Zipformer( + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + decode_chunk_size=4, + ) + # Just make sure the forward pass runs. 
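+    # The encoder consumes (batch, time, feature_dim) features together with the
+    # number of valid frames per utterance; the assert below checks the output
+    # length after temporal subsampling.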
+ f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +def _test_conv2d_subsampling(): + num_features = 80 + encoder_dims = 384 + dropout = 0.1 + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + for i in range(20, 40): + x = torch.rand(2, i, num_features) + y = encoder_embed(x) + assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + + +def _test_pooling_module(): + N, S, C = 2, 12, 32 + chunk_len = 4 + m = PoolingModule(d_model=C) + + # test chunk-wise forward with padding_mask + x = torch.randn(S, N, C) + y = m(x) + cached_len = torch.zeros(N, dtype=torch.int32) + cached_avg = torch.zeros(N, C) + for i in range(S // chunk_len): + start = i * chunk_len + end = start + chunk_len + x_chunk = x[start:end] + y_chunk, cached_len, cached_avg = m.streaming_forward( + x_chunk, + cached_len=cached_len, + cached_avg=cached_avg, + ) + assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) + + +def _test_state_stack_unstack(): + m = Zipformer( + num_features=80, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + zipformer_downsampling_factors=(4, 8), + num_left_chunks=2, + decode_chunk_size=8, + ) + s1 = m.get_init_state() + s2 = m.get_init_state() + states = stack_states([s1, s2]) + new_s1, new_s2 = unstack_states(states) + for i in range(m.num_encoders * 7): + for x, y in zip(s1[i], new_s1[i]): + assert torch.equal(x, y) + for x, y in zip(s2[i], new_s2[i]): + assert torch.equal(x, y) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main() + _test_conv2d_subsampling() + _test_pooling_module() + _test_state_stack_unstack() From a54b748a02550fe523b05c6ccc5845170ba99114 Mon Sep 17 00:00:00 2001 From: behnamasefi Date: Fri, 30 Dec 2022 03:06:09 +0000 Subject: [PATCH 093/120] check for utterance len (#795) Co-authored-by: behnam --- .../pruned_transducer_stateless7_ctc/train.py | 28 ++++++++++++++++++- .../train.py | 28 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 162ad8412..5a05e1836 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1086,7 +1086,33 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 63e9d6e90..522ecc974 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1077,7 +1077,33 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) From 67ae5fdf2bf2b09d2ce9e5acb7dab12b2d2fc441 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 30 Dec 2022 15:21:18 +0800 Subject: [PATCH 094/120] Doc streaming zipformer (#798) * add doc for streaming_zipformer * update README.md --- .../Streaming-ASR/librispeech/index.rst | 2 + .../librispeech/zipformer_transducer.rst | 654 ++++++++++++++++++ .../README.md | 3 + egs/librispeech/ASR/zipformer_mmi/README.md | 2 +- 4 files changed, 660 insertions(+), 1 deletion(-) create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md diff --git a/docs/source/recipes/Streaming-ASR/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst index 546ce168b..d52e08058 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/index.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/index.rst @@ -7,3 +7,5 @@ LibriSpeech pruned_transducer_stateless lstm_pruned_stateless_transducer + + zipformer_transducer diff --git a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst new file mode 100644 index 000000000..f0e8961d7 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst @@ -0,0 +1,654 @@ +Zipformer Transducer +==================== + +This tutorial shows you how to run a **streaming** zipformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless7_streaming `_, + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Zipformer model (proposed by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. 
HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless7_streaming/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless7_streaming/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless7_streaming/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless7_streaming/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. 
+ + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + We recommend using ``--use-fp16 True``. + + - ``--short-chunk-size`` + + When training a streaming attention model with chunk masking, the chunk size + would be either max sequence length of current batch or uniformly sampled from + (1, short_chunk_size). The default value is 50, you don't have to change it most of the time. + + - ``--num-left-chunks`` + + It indicates how many left context (in chunks) that can be seen when calculating attention. + The default value is 4, you don't have to change it most of the time. + + + - ``--decode-chunk-len`` + + The chunk size for decoding (in frames before subsampling). It is used for validation. + The default value is 32 (i.e., 320ms). + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless7_streaming/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless7_streaming/train.py`` directly. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless7_streaming/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_streaming/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_streaming/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless7_streaming/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless7_streaming/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. 
+ + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless7_streaming/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless7_streaming/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. tip:: + + To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or + ``real chunk-wise streaming decoding`` in ``streaming_decode.py``. The difference between ``decode.py`` and + ``streaming_decode.py`` is that, ``decode.py`` processes the whole acoustic frames at one time with masking (i.e. same as training), + but ``streaming_decode.py`` processes the acoustic frames chunk by chunk. + +.. NOTE:: + + ``simulate streaming decoding`` in ``decode.py`` and ``real chunk-size streaming decoding`` in ``streaming_decode.py`` should + produce almost the same results given the same ``--decode-chunk-len``. + + +Simulate streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-len`` + + It is same as in ``train.py``, which specifies the chunk size for decoding (in frames before subsampling). + The default value is 32 (i.e., 320ms). + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 30; do + for avg in 12 11 10 9 8; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --iter $iter \ + --avg $avg \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +Real streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/streaming_decode.py --help + +shows the options for decoding. 
+
+The following options are important for streaming models:
+
+    ``--decode-chunk-len``
+
+    It is the same as in ``train.py``, which specifies the chunk size for decoding (in frames before subsampling).
+    The default value is 32 (i.e., 320ms).
+    For ``real streaming decoding``, we process ``decode-chunk-len`` acoustic frames at a time.
+
+    ``--num-decode-streams``
+
+    The number of decoding streams that can be run in parallel (very similar to the ``batch size``).
+    For ``real streaming decoding``, the batches are packed dynamically. For example, if
+    ``num-decode-streams`` equals 10, then sequences 1 to 10 are decoded first; once, say,
+    sequences 1 and 2 are done, sequences 3 to 12 are processed in parallel in a batch.
+
+
+The following shows two examples (for the two types of checkpoints):
+
+.. code-block:: bash
+
+  for m in greedy_search fast_beam_search modified_beam_search; do
+    for epoch in 30; do
+      for avg in 12 11 10 9 8; do
+        ./pruned_transducer_stateless7_streaming/streaming_decode.py \
+            --epoch $epoch \
+            --avg $avg \
+            --decode-chunk-len 32 \
+            --num-decode-streams 100 \
+            --exp-dir pruned_transducer_stateless7_streaming/exp \
+            --decoding-method $m
+      done
+    done
+  done
+
+
+.. code-block:: bash
+
+  for m in greedy_search fast_beam_search modified_beam_search; do
+    for iter in 474000; do
+      for avg in 8 10 12 14 16 18; do
+        ./pruned_transducer_stateless7_streaming/streaming_decode.py \
+            --iter $iter \
+            --avg $avg \
+            --decode-chunk-len 16 \
+            --num-decode-streams 100 \
+            --exp-dir pruned_transducer_stateless7_streaming/exp \
+            --decoding-method $m
+      done
+    done
+  done
+
+
+.. tip::
+
+  The supported decoding methods are as follows:
+
+    - ``greedy_search`` : It takes the symbol with the largest posterior probability
+      of each frame as the decoding result.
+
+    - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
+      `espnet/nets/beam_search_transducer.py `_
+      is used as a reference. Basically, it keeps the top-k states for each frame and expands the
+      kept states with their own contexts to the next frame.
+
+    - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+      runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+    - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+      given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+      our paper at https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+    - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+      a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+      (with an N-gram LM).
+
+    - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+      - (1) Use ``fast_beam_search`` to get a lattice
+      - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+      - (3) Unique the selected paths
+      - (4) Intersect the selected paths with the lattice and compute the
+        shortest path from the intersection result
+      - (5) The path with the largest score is used as the decoding output.
+
+    - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the
+      only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
+..
NOTE:: + + The supporting decoding methods in ``streaming_decode.py`` might be less than that in ``decode.py``, if needed, + you can implement them by yourself or file a issue in `icefall `_ . + + +Export Model +------------ + +Currently it supports exporting checkpoints from ``pruned_transducer_stateless7_streaming/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless7_streaming/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --epoch 30 --avg 9 produces the smallest WER + # (You can get such information after running ./pruned_transducer_stateless7_streaming/decode.py) + + epoch=30 + avg=9 + + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +It will generate a file ``./pruned_transducer_stateless7_streaming/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_streaming/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless7_streaming/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless7_streaming/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless7_streaming/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --decode-chunk-len 32 \ + --jit 1 + +.. caution:: + + ``--decode-chunk-len`` is required to export a ScriptModule. + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +Export model using ``torch.jit.trace()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + epoch=30 + avg=9 + + ./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model=True \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch $epoch \ + --avg $avg + +.. caution:: + + ``--decode-chunk-len`` is required to export a ScriptModule. 
+ +It will generate 3 files: + + - ``./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt`` + - ``./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt`` + - ``./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt`` + +To use the generated files with ``./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + /path/to/foo.wav + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless7_streaming `_ + + See ``_ + for the details of the above pretrained models + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 100644 index 000000000..6e461e196 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1,3 @@ +This recipe implements Streaming Zipformer-Transducer model. + +See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. diff --git a/egs/librispeech/ASR/zipformer_mmi/README.md b/egs/librispeech/ASR/zipformer_mmi/README.md index 8ca844180..e9a37a52a 100644 --- a/egs/librispeech/ASR/zipformer_mmi/README.md +++ b/egs/librispeech/ASR/zipformer_mmi/README.md @@ -1,6 +1,6 @@ This recipe implements Zipformer-MMI model. -See https://k2-fsa.github.io/icefall/recipes/librispeech/zipformer_mmi.html for detailed tutorials. +See https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.html for detailed tutorials. It uses **CTC loss for warm-up** and then switches to MMI loss during training. 
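To make the list of decoding methods in the tutorial above a bit more concrete, here is a minimal sketch of the per-frame ``greedy_search`` loop of a transducer (an editor's illustration, not part of this patch). The ``decoder``/``joiner`` stand-ins, dimensions, and vocabulary size below are made up purely for illustration; the real implementation lives in ``beam_search.py``.

.. code-block:: python

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the real decoder/joiner modules (illustration only).
    vocab_size, enc_dim, dec_dim, context_size, blank_id = 500, 384, 512, 2, 0
    embed = nn.Embedding(vocab_size, dec_dim)
    enc_to_vocab = nn.Linear(enc_dim, vocab_size)
    dec_to_vocab = nn.Linear(dec_dim, vocab_size)

    def decoder(y):  # (1, context_size) int64 -> (1, dec_dim)
        return embed(y).mean(dim=1)

    def joiner(enc_frame, dec_out):  # (1, enc_dim), (1, dec_dim) -> (1, vocab_size)
        return enc_to_vocab(enc_frame) + dec_to_vocab(dec_out)

    def greedy_search_sketch(encoder_out: torch.Tensor):
        """encoder_out: (T, enc_dim) for a single utterance; returns token ids."""
        hyp = [blank_id] * context_size
        dec_out = decoder(torch.tensor([hyp], dtype=torch.int64))
        for t in range(encoder_out.size(0)):
            logits = joiner(encoder_out[t : t + 1], dec_out)
            y = int(logits.argmax(dim=-1))
            if y != blank_id:  # emit at most one symbol per frame
                hyp.append(y)
                dec_out = decoder(torch.tensor([hyp[-context_size:]], dtype=torch.int64))
        return hyp[context_size:]

    print(greedy_search_sketch(torch.randn(10, enc_dim)))

With random weights the output is meaningless; the point is only the control flow: one decoder update per emitted symbol, at most one symbol per encoder frame.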
From 2fd970b6821d47dacb2e6513321520db21fff67b Mon Sep 17 00:00:00 2001 From: Daniil Date: Sun, 1 Jan 2023 19:08:32 -0500 Subject: [PATCH 095/120] not removing result_dir in tedlium conformer ctc2 + add lm stem to compile_hlg_using_openfst.py + add MASTER_ADDR to be prvided to setup_dist (#801) --- .../ASR/local/compile_hlg_using_openfst.py | 19 ++++++++++++++----- egs/tedlium3/ASR/conformer_ctc2/decode.py | 7 ++----- icefall/dist.py | 8 ++++++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py index 9e5e3df69..15fc47ef1 100755 --- a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py +++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py @@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from Caution: We use a lexicon that contains disambiguation symbols - - G, the LM, built from data/lm/G_3_gram.fst.txt + - G, the LM, built from data/lm/G_n_gram.fst.txt The generated HLG is saved in $lang_dir/HLG_fst.pt @@ -46,6 +46,13 @@ from icefall.lexicon import Lexicon def get_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) parser.add_argument( "--lang-dir", type=str, @@ -56,11 +63,13 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst: """ Args: lang_dir: The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. Return: An FST representing HLG. @@ -71,8 +80,8 @@ def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: kaldifst.arcsort(L, sort_type="olabel") logging.info(f"L: #states {L.num_states}") - G_filename_txt = "data/lm/G_3_gram.fst.txt" - G_filename_binary = "data/lm/G_3_gram.fst" + G_filename_txt = f"data/lm/{lm}.fst.txt" + G_filename_binary = f"data/lm/{lm}.fst" if Path(G_filename_binary).is_file(): logging.info(f"Loading {G_filename_binary}") G = kaldifst.StdVectorFst.read(G_filename_binary) @@ -171,7 +180,7 @@ def main(): logging.info(f"{filename} already exists - skipping") return - HLG = compile_HLG(lang_dir) + HLG = compile_HLG(lang_dir, args.lm) logging.info(f"Saving HLG to {filename}") torch.save(HLG.as_dict(), filename) diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py index ce4dcd142..28d39de70 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/decode.py +++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py @@ -20,7 +20,6 @@ import argparse import logging -import shutil from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -183,7 +182,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument( "--result-dir", type=str, - default="conformer_ctc2/exp", + default="conformer_ctc2/exp/results", help="Directory to store results.", ) @@ -635,9 +634,7 @@ def main() -> None: args.lm_path = Path(args.lm_path) args.result_dir = Path(args.result_dir) - if args.result_dir.is_dir(): - shutil.rmtree(args.result_dir) - args.result_dir.mkdir() + args.result_dir.mkdir(exist_ok=True) params = get_params() params.update(vars(args)) diff --git a/icefall/dist.py b/icefall/dist.py index 9df1c5bd1..672948623 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -21,12 +21,16 @@ import torch from torch import distributed as dist -def setup_dist(rank, 
world_size, master_port=None, use_ddp_launch=False): +def setup_dist( + rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False +): """ rank and world_size are used only if use_ddp_launch is False. """ if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_ADDR"] = ( + "localhost" if master_addr is None else str(master_addr) + ) if "MASTER_PORT" not in os.environ: os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port) From 80cce141b4235c9bf0d6a903f202e1217d56c18b Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 3 Jan 2023 15:40:53 +0800 Subject: [PATCH 096/120] Full libri fix manifest (#804) * modify the name of the directory of vq manifest * fix missing manifest in full libri training --- .../ASR/distillation_with_hubert.sh | 22 +++++++++++++++---- .../pruned_transducer_stateless6/vq_utils.py | 15 ++++++++++--- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index d5d3008aa..a38cf590c 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -43,7 +43,7 @@ mkdir -p $exp_dir # full_libri can be "True" or "False" # "True" -> use full librispeech dataset for distillation # "False" -> use train-clean-100 subset for distillation -full_libri=False +full_libri=True # use_extracted_codebook can be "True" or "False" # "True" -> stage 0 and stage 1 would be skipped, @@ -145,8 +145,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960" exit 1 fi + # The codebook indexes to be downloaded are generated using the following setup: + embedding_layer=36 + num_codebooks=8 + mkdir -p $exp_dir/vq - codebook_dir=$exp_dir/vq/$teacher_model_id + codebook_dir=$exp_dir/vq/${teacher_model_id}_layer${embedding_layer}_cb${num_codebooks} mkdir -p codebook_dir codebook_download_dir=$exp_dir/download_codebook if [ -d $codebook_download_dir ]; then @@ -164,8 +168,9 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then git lfs install git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir - mkdir -p data/vq_fbank - mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ + vq_fbank=data/vq_fbank_layer${embedding_layer}_cb${num_codebooks}/ + mkdir -p $vq_fbank + mv $codebook_download_dir/*.jsonl.gz $vq_fbank mkdir -p $codebook_dir/splits4 mv $codebook_download_dir/*.h5 $codebook_dir/splits4/ log "Remove $codebook_download_dir" @@ -181,6 +186,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --max-duration 100 \ --teacher-model-id $teacher_model_id \ --use-extracted-codebook $use_extracted_codebook + + if [ "$full_libri" == "True" ]; then + # Merge the 3 subsets and create a full one + rm ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 
bf072d865..14ff86f23 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,10 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir + / f"vq/{self.params.teacher_model_id}_layer{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -79,7 +82,10 @@ class CodebookIndexExtractor: # It's doesn't matter whether ori_manifest_dir is str or Path. # Set it to Path to be consistent. self.ori_manifest_dir = Path("./data/fbank/") - self.dst_manifest_dir = Path("./data/vq_fbank/") + self.dst_manifest_dir = Path( + f"./data/vq_fbank_layer" + + f"{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.dst_manifest_dir.mkdir(parents=True, exist_ok=True) @@ -284,7 +290,10 @@ class CodebookIndexExtractor: Merge generated vq included manfiests and storage to self.dst_manifest_dir. """ for subset in self.params.subsets: - vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + vq_manifests = ( + f"{self.manifest_dir}/" + + f"with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + ) dst_vq_manifest = ( self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) From 0f26edfde96d48406f2f227ed87584c6a94f3f68 Mon Sep 17 00:00:00 2001 From: Yunusemre Date: Tue, 3 Jan 2023 08:59:44 +0000 Subject: [PATCH 097/120] Add Zipformer Onnx Support (#778) * add export script * add zipformer onnx pretrained script * add onnx zipformer test * fix style * add zipformer onnx to workflow * replace is_in_onnx_export with is_tracing * add github.event.label.name == 'onnx' * add is_tracing to necessary conditions * fix pooling_mask * add onnx_check * add onnx_check to scripts * add is_tracing to scaling.py --- ...pruned-transducer-stateless7-2022-11-11.sh | 30 ++ .../run-librispeech-2022-11-11-stateless7.yml | 2 +- .../pruned_transducer_stateless7/export.py | 267 +++++++++++- .../onnx_check.py | 286 +++++++++++++ .../onnx_pretrained.py | 388 ++++++++++++++++++ .../pruned_transducer_stateless7/scaling.py | 7 +- .../pruned_transducer_stateless7/test_onnx.py | 374 +++++++++++++++++ .../pruned_transducer_stateless7/zipformer.py | 57 ++- 8 files changed, 1383 insertions(+), 28 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 8e485d2e6..999841b80 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -30,6 +30,15 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd +log "Test exporting to ONNX format" +./pruned_transducer_stateless7/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --onnx 1 + log "Export to torchscript model" 
./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ @@ -41,6 +50,27 @@ log "Export to torchscript model" ls -lh $repo/exp/*.pt +log "Decode with ONNX models" + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder.onnx \ + --onnx-decoder-filename $repo/exp/decoder.onnx \ + --onnx-joiner-filename $repo/exp/joiner.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + log "Decode with models exported by torch.jit.script()" ./pruned_transducer_stateless7/jit_pretrained.py \ diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 365e2761a..7694e8bf5 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..db8b5eb2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -41,7 +41,31 @@ Check https://github.com/k2-fsa/sherpa for how to use the exported models outside of icefall. -(2) Export `model.state_dict()` +(2) Export to ONNX format + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(3) Export `model.state_dict()` ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ @@ -172,6 +196,23 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. 
+ """, + ) + parser.add_argument( "--context-size", type=int, @@ -184,6 +225,204 @@ def get_parser(): return parser +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 101, 80, dtype=torch.float32) + x_lens = torch.tensor([101], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -292,7 +531,31 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.onnx is True: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. 
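Before running the ``onnx_check.py`` script added below, a quick way to verify the exported ``encoder.onnx`` is to push a random batch through it with ``onnxruntime``. The following smoke-test sketch is not part of this patch; the file path is an example, the 80-dim fbank input follows the export code above, and the 384-dim encoder output corresponds to the default zipformer configuration.

.. code-block:: python

    import onnxruntime as ort
    import torch

    opts = ort.SessionOptions()
    opts.inter_op_num_threads = 1
    opts.intra_op_num_threads = 1

    # Example path; point it at the encoder.onnx produced by export.py above.
    session = ort.InferenceSession(
        "pruned_transducer_stateless7/exp/encoder.onnx",
        sess_options=opts,
    )

    x = torch.rand(1, 100, 80, dtype=torch.float32)  # (N, T, num_features)
    x_lens = torch.tensor([100], dtype=torch.int64)  # (N,)

    encoder_out, encoder_out_lens = session.run(
        ["encoder_out", "encoder_out_lens"],
        {"x": x.numpy(), "x_lens": x_lens.numpy()},
    )
    # Expect roughly (1, T', 384) with the default configuration,
    # where T' reflects the subsampling inside the encoder.
    print(encoder_out.shape, encoder_out_lens)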
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 100755 index 000000000..63acc0922 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + parser.add_argument( + "--onnx-joiner-encoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner encoder projection model", + ) + + parser.add_argument( + "--onnx-joiner-decoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner decoder projection model", + ) + + return parser + + +def test_encoder( + model: torch.jit.ScriptModule, + encoder_session: ort.InferenceSession, +): + inputs = encoder_session.get_inputs() + outputs = encoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", "T", 80] + assert inputs[1].shape == ["N"] + + for N in [1, 5]: + for T in [12, 50]: + print("N, T", N, T) + x = torch.rand(N, T, 80, dtype=torch.float32) + x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens[0] = T + + encoder_inputs = { + input_names[0]: x.numpy(), + input_names[1]: x_lens.numpy(), + } + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out, encoder_out_lens = encoder_session.run( + output_names, + encoder_inputs, + ) + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out = torch.from_numpy(encoder_out) + assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( + (encoder_out - torch_encoder_out).abs().max(), + encoder_out.shape, + torch_encoder_out.shape, + ) + + +def test_decoder( + model: torch.jit.ScriptModule, + decoder_session: ort.InferenceSession, +): + 
inputs = decoder_session.get_inputs() + outputs = decoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", 2] + for N in [1, 5, 10]: + y = torch.randint(low=1, high=500, size=(10, 2)) + + decoder_inputs = {input_names[0]: y.numpy()} + decoder_out = decoder_session.run( + output_names, + decoder_inputs, + )[0] + decoder_out = torch.from_numpy(decoder_out) + + torch_decoder_out = model.decoder(y, need_pad=False) + assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( + (decoder_out - torch_decoder_out).abs().max() + ) + + +def test_joiner( + model: torch.jit.ScriptModule, + joiner_session: ort.InferenceSession, + joiner_encoder_proj_session: ort.InferenceSession, + joiner_decoder_proj_session: ort.InferenceSession, +): + joiner_inputs = joiner_session.get_inputs() + joiner_outputs = joiner_session.get_outputs() + joiner_input_names = [n.name for n in joiner_inputs] + joiner_output_names = [n.name for n in joiner_outputs] + + assert joiner_inputs[0].shape == ["N", 1, 1, 512] + assert joiner_inputs[1].shape == ["N", 1, 1, 512] + + joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() + encoder_proj_input_name = joiner_encoder_proj_inputs[0].name + + assert joiner_encoder_proj_inputs[0].shape == ["N", 384] + + joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() + encoder_proj_output_name = joiner_encoder_proj_outputs[0].name + + joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() + decoder_proj_input_name = joiner_decoder_proj_inputs[0].name + + assert joiner_decoder_proj_inputs[0].shape == ["N", 512] + + joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() + decoder_proj_output_name = joiner_decoder_proj_outputs[0].name + + for N in [1, 5, 10]: + encoder_out = torch.rand(N, 384) + decoder_out = torch.rand(N, 512) + + projected_encoder_out = torch.rand(N, 1, 1, 512) + projected_decoder_out = torch.rand(N, 1, 1, 512) + + joiner_inputs = { + joiner_input_names[0]: projected_encoder_out.numpy(), + joiner_input_names[1]: projected_decoder_out.numpy(), + } + joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] + joiner_out = torch.from_numpy(joiner_out) + + torch_joiner_out = model.joiner( + projected_encoder_out, + projected_decoder_out, + project_input=False, + ) + assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( + (joiner_out - torch_joiner_out).abs().max() + ) + + # Now test encoder_proj + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_out = joiner_encoder_proj_session.run( + [encoder_proj_output_name], joiner_encoder_proj_inputs + )[0] + joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) + + torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) + assert torch.allclose( + joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + + # Now test decoder_proj + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_out = joiner_decoder_proj_session.run( + [decoder_proj_output_name], joiner_decoder_proj_inputs + )[0] + joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) + + torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) + assert torch.allclose( + joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 + ), 
((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + model = torch.jit.load(args.jit_filename) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + logging.info("Test encoder") + encoder_session = ort.InferenceSession( + args.onnx_encoder_filename, + sess_options=options, + ) + test_encoder(model, encoder_session) + + logging.info("Test decoder") + decoder_session = ort.InferenceSession( + args.onnx_decoder_filename, + sess_options=options, + ) + test_decoder(model, decoder_session) + + logging.info("Test joiner") + joiner_session = ort.InferenceSession( + args.onnx_joiner_filename, + sess_options=options, + ) + joiner_encoder_proj_session = ort.InferenceSession( + args.onnx_joiner_encoder_proj_filename, + sess_options=options, + ) + joiner_decoder_proj_session = ort.InferenceSession( + args.onnx_joiner_decoder_proj_filename, + sess_options=options, + ) + test_joiner( + model, + joiner_session, + joiner_encoder_proj_session, + joiner_decoder_proj_session, + ) + logging.info("Finished checking ONNX models") + + +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 100755 index 000000000..3a06ee293 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. 
+ joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = 
parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 1cbde6db0..156b91f09 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -261,7 +261,7 @@ class RandomGrad(torch.nn.Module): self.min_abs = min_abs def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return x else: return RandomGradFunction.apply(x, self.min_abs) @@ -530,7 +530,7 @@ class ActivationBalancer(torch.nn.Module): self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad: + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count @@ -790,7 +790,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x else: # a 
no-op function that will have a node in the autograd graph, @@ -862,6 +862,7 @@ class MaxEig(torch.nn.Module): torch.jit.is_scripting() or self.max_var_per_eig <= 0 or random.random() > self.cur_prob + or torch.jit.is_tracing() ): return _no_op(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py new file mode 100644 index 000000000..2440d267c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file is to test that models can be exported to onnx. +""" +import os + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + +import onnxruntime as ort +import torch +from scaling_converter import convert_scaled_to_non_scaled +from zipformer import ( + Conv2dSubsampling, + RelPositionalEncoding, + Zipformer, + ZipformerEncoder, + ZipformerEncoderLayer, +) + +ort.set_default_logger_severity(3) + + +def test_conv2d_subsampling(): + filename = "conv2d_subsampling.onnx" + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_embed = Conv2dSubsampling(num_features, d_model) + encoder_embed.eval() + encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True) + + torch.onnx.export( + encoder_embed, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + + onnx_y = session.run(["y"], inputs)[0] + + onnx_y = torch.from_numpy(onnx_y) + torch_y = encoder_embed(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + os.remove(filename) + + +def test_rel_pos(): + filename = "rel_pos.onnx" + + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + + torch.onnx.export( + encoder_pos, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["pos_emb"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "pos_emb": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + 
options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + onnx_pos_emb = session.run(["pos_emb"], inputs) + onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0]) + + torch_pos_emb = encoder_pos(x) + assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( + (onnx_pos_emb - torch_pos_emb).abs().max() + ) + print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum()) + + os.remove(filename) + + +def test_zipformer_encoder_layer(): + filename = "zipformer_encoder_layer.onnx" + opset_version = 13 + N = 30 + T = 50 + + d_model = 384 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + + x = torch.rand(N, T, d_model) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + pos_emb = encoder_pos(x) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder_layer.eval() + encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) + + torch.onnx.export( + encoder_layer, + (x, pos_emb), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder_layer(x, pos_emb) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer_encoder(): + filename = "zipformer_encoder.onnx" + + opset_version = 13 + N = 3 + T = 15 + + d_model = 512 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + num_encoder_layers = 12 + + warmup_batches = 4000.0 + warmup_begin = warmup_batches / (num_encoder_layers + 1) + warmup_end = warmup_batches / (num_encoder_layers + 1) + + x = torch.rand(N, T, d_model) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder = ZipformerEncoder( + encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end + ) + encoder.eval() + encoder = convert_scaled_to_non_scaled(encoder, inplace=True) + + # jit_model = torch.jit.trace(encoder, (pos_emb)) + + torch_y = encoder(x) + + torch.onnx.export( + encoder, + (x), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + 
) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer(): + filename = "zipformer.onnx" + opset_version = 11 + N = 3 + T = 15 + num_features = 80 + x = torch.rand(N, T, num_features) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + + zipformer = Zipformer(num_features=num_features) + zipformer.eval() + zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True) + + # jit_model = torch.jit.trace(zipformer, (x, x_lens)) + torch.onnx.export( + zipformer, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["y", "y_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "y": {0: "N", 1: "T"}, + "y_lens": {0: "N"}, + }, + ) + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: x_lens.numpy(), + } + onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_y_lens = torch.from_numpy(onnx_y_lens) + + torch_y, torch_y_lens = zipformer(x, x_lens) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( + (onnx_y_lens - torch_y_lens).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + print(onnx_y_lens, torch_y_lens) + + os.remove(filename) + + +@torch.no_grad() +def main(): + test_conv2d_subsampling() + test_rel_pos() + test_zipformer_encoder_layer() + test_zipformer_encoder() + test_zipformer() + + +if __name__ == "__main__": + torch.manual_seed(20221011) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d18258085..b1717ec64 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -210,7 +210,7 @@ class Zipformer(EncoderInterface): (num_frames, batch_size, encoder_dims0) """ num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape @@ -293,7 +293,7 @@ class Zipformer(EncoderInterface): k = self.skip_layers[i] if isinstance(k, int): layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) @@ -386,7 +386,7 @@ class ZipformerEncoderLayer(nn.Module): ) def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return self.bypass_scale if random.random() < 0.1: # ensure we get grads if self.bypass_scale becomes out of range @@ -407,7 +407,7 @@ class 
ZipformerEncoderLayer(nn.Module): # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return 0.0 warmup_period = 2000.0 initial_dropout_rate = 0.2 @@ -452,12 +452,12 @@ class ZipformerEncoderLayer(nn.Module): dynamic_dropout = self.get_dynamic_dropout_rate() # pooling module - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) elif random.random() >= dynamic_dropout: src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src_att, attn_weights = self.self_attn( src, pos_emb=pos_emb, @@ -658,7 +658,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): layers_to_drop = [] else: rnd_seed = src.numel() + random.randint(0, 1000) @@ -667,7 +667,7 @@ class ZipformerEncoder(nn.Module): output = output * feature_mask for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if i in layers_to_drop: continue output = mod( @@ -864,7 +864,7 @@ class SimpleCombiner(torch.nn.Module): assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) weight1 = self.weight1 - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if ( self.training and random.random() < 0.25 @@ -1258,21 +1258,31 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if training and random.random() < 0.1: # This is a harder way of limiting the attention scores to not be too large. 
# It incurs a penalty if any of them has an absolute value greater than 50.0. @@ -1383,7 +1393,7 @@ class RelPositionMultiheadAttention(nn.Module): # now v: (bsz * num_heads, seq_len, head_dim // 2) attn_output = torch.bmm(attn_weights, v) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if random.random() < 0.001 or __name__ == "__main__": self._print_attn_stats(attn_weights, attn_output) @@ -1458,7 +1468,10 @@ class PoolingModule(nn.Module): a Tensor of shape (1, N, C) """ if key_padding_mask is not None: - pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) + if torch.jit.is_tracing(): + pooling_mask = (~key_padding_mask).to(x.dtype) + else: + pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) From 8642dbc0bd4174acb6612b6510f971f98a16f7d3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 4 Jan 2023 12:21:19 +0800 Subject: [PATCH 098/120] Fix setup_dist (#806) --- icefall/dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/dist.py b/icefall/dist.py index 672948623..922f31a2f 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -22,7 +22,7 @@ from torch import distributed as dist def setup_dist( - rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False + rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None ): """ rank and world_size are used only if use_ddp_launch is False. From b9626f2e0684dd59d761c6bd6e9b6127a387d11c Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 5 Jan 2023 17:18:43 +0800 Subject: [PATCH 099/120] fix typo for ctc-decode.py (#815) Co-authored-by: yifanyang --- .../ASR/pruned_transducer_stateless7_ctc/ctc_decode.py | 2 +- .../ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 9c23e7d66..4b373e4c7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -44,7 +44,7 @@ Usage: --exp-dir ./pruned_transducer_stateless7_ctc/exp \ --max-duration 600 \ --hlg-scale 0.8 \ - --decoding-method 1best + --decoding-method nbest (4) nbest-rescoring ./pruned_transducer_stateless7_ctc/ctc_decode.py \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index 0ef733226..f137485b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -42,7 +42,7 @@ Usage: --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --max-duration 600 \ --hlg-scale 0.8 \ - --decoding-method 1best + --decoding-method nbest (4) nbest-rescoring ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ --epoch 30 \ From 9a9c5a0f9b083a729ee00d439df1054f517e1b6d Mon Sep 17 00:00:00 2001 From: kobenaxie <572745565@qq.com> Date: Fri, 6 Jan 2023 11:16:22 +0800 Subject: [PATCH 100/120] remove unused codes. 
(#821) --- .../emformer2.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py index 188059044..f0c92a9b4 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -1512,24 +1512,6 @@ class EmformerEncoder(nn.Module): ) return states - attn_caches = [ - [ - torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - ] - for _ in range(self.num_encoder_layers) - ] - conv_caches = [ - torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device) - for _ in range(self.num_encoder_layers) - ] - states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = ( - attn_caches, - conv_caches, - ) - return states - class Emformer(EncoderInterface): def __init__( From 9453eb1c709140becd3373666bb51e996e8f7260 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 6 Jan 2023 17:00:27 +0800 Subject: [PATCH 101/120] Fix doc for building ncnn (#822) --- docs/README.md | 24 ++++++++++++++ .../lstm_pruned_stateless_transducer.rst | 31 ++++++++++++++++--- 2 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 docs/README.md diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..3abb38f8b --- /dev/null +++ b/docs/README.md @@ -0,0 +1,24 @@ + +## Usage + +```bash +cd /path/to/icefall/docs +pip install -r requirements.txt +make clean +make html +cd build/html +python3 -m http.server 8000 +``` + +It prints: + +``` +Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ... +``` + +Open your browser and go to to view the generated +documentation. + +Done! + +**Hint**: You can change the port number when starting the server. diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index 643855cc2..d09421eb5 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -531,16 +531,36 @@ First, let us install a modified version of ``ncnn``: git clone https://github.com/csukuangfj/ncnn cd ncnn git submodule update --recursive --init - python3 setup.py bdist_wheel - ls -lh dist/ - pip install ./dist/*.whl + + # Note: We don't use "python setup.py install" or "pip install ." here + + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=OFF \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is /path/to/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build/tools/quantize:$PATH # now build pnnx cd tools/pnnx mkdir build cd build + cmake .. make -j4 - export PATH=$PWD/src:$PATH ./src/pnnx @@ -549,6 +569,9 @@ First, let us install a modified version of ``ncnn``: We assume that you have added the path to the binary ``pnnx`` to the environment variable ``PATH``. + We also assume that you have added ``build/tools/quantize`` to the environment + variable ``PATH`` so that you are able to use ``ncnn2int8`` later. 
+ Second, let us export the model using ``torch.jit.trace()`` that is suitable for ``pnnx``: From 42cc10117eed5960e7219bbb9501a0beda602cfa Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 9 Jan 2023 15:08:39 +0800 Subject: [PATCH 102/120] Fix ncnn install (#824) * add README to docs * fix ncnn installation --- .../librispeech/lstm_pruned_stateless_transducer.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index d09421eb5..22addd1d2 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -542,7 +542,7 @@ First, let us install a modified version of ``ncnn``: -DNCNN_PYTHON=ON \ -DNCNN_BUILD_BENCHMARK=OFF \ -DNCNN_BUILD_EXAMPLES=OFF \ - -DNCNN_BUILD_TOOLS=OFF \ + -DNCNN_BUILD_TOOLS=ON \ .. make -j4 @@ -553,7 +553,7 @@ First, let us install a modified version of ``ncnn``: export PYTHONPATH=$PWD/python:$PYTHONPATH export PATH=$PWD/tools/pnnx/build/src:$PATH - export PATH=$PWD/build/tools/quantize:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH # now build pnnx cd tools/pnnx From fcffa593f011bd3213af5af044eb3ce2ede666c1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Jan 2023 15:38:33 +0800 Subject: [PATCH 103/120] Add FAQs to doc (#827) * Add FAQs * small fixes --- docs/source/faqs.rst | 67 +++++++++++++++++++++++++++++++++++++++++++ docs/source/index.rst | 1 + 2 files changed, 68 insertions(+) create mode 100644 docs/source/faqs.rst diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst new file mode 100644 index 000000000..c70ded431 --- /dev/null +++ b/docs/source/faqs.rst @@ -0,0 +1,67 @@ +Frequently Asked Questions (FAQs) +================================= + +In this section, we collect issues reported by users and post the corresponding +solutions. + + +OSError: libtorch_hip.so: cannot open shared object file: no such file or directory +----------------------------------------------------------------------------------- + +One user is using the following code to install ``torch`` and ``torchaudio``: + +.. code-block:: bash + + pip install \ + torch==1.10.0+cu111 \ + torchvision==0.11.0+cu111 \ + torchaudio==0.10.0 \ + -f https://download.pytorch.org/whl/torch_stable.html + +and it throws the following error when running ``tdnn/train.py``: + +.. code-block:: + + OSError: libtorch_hip.so: cannot open shared object file: no such file or directory + +The fix is to specify the CUDA version while installing ``torchaudio``. That +is, change ``torchaudio==0.10.0`` to ``torchaudio==0.10.0+cu11```. Therefore, +the correct command is: + +.. code-block:: bash + + pip install \ + torch==1.10.0+cu111 \ + torchvision==0.11.0+cu111 \ + torchaudio==0.10.0+cu111 \ + -f https://download.pytorch.org/whl/torch_stable.html + +AttributeError: module 'distutils' has no attribute 'version' +------------------------------------------------------------- + +The error log is: + +.. 
code-block:: + + Traceback (most recent call last): + File "./tdnn/train.py", line 14, in + from asr_datamodule import YesNoAsrDataModule + File "/home/xxx/code/next-gen-kaldi/icefall/egs/yesno/ASR/tdnn/asr_datamodule.py", line 34, in + from icefall.dataset.datamodule import DataModule + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/__init__.py", line 3, in + from . import ( + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/decode.py", line 23, in + from icefall.utils import add_eos, add_sos, get_texts + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/utils.py", line 39, in + from torch.utils.tensorboard import SummaryWriter + File "/home/xxx/tool/miniconda3/envs/yyy/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py", line 4, in + LooseVersion = distutils.version.LooseVersion + AttributeError: module 'distutils' has no attribute 'version' + +The fix is: + +.. code-block:: bash + + pip uninstall setuptools + + pip install setuptools==58.0.4 diff --git a/docs/source/index.rst b/docs/source/index.rst index 4ea446259..8d76eb68b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,7 @@ speech recognition recipes using `k2 `_. :caption: Contents: installation/index + faqs model-export/index .. toctree:: From c05f5d76df6e9cc208b99308a8e426e54e9be69e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Jan 2023 20:52:13 +0800 Subject: [PATCH 104/120] fix decoding for ncnn (#828) --- .../streaming-ncnn-decode.py | 8 +++++--- .../ASR/lstm_transducer_stateless2/ncnn-decode.py | 8 +++++--- .../lstm_transducer_stateless2/streaming-ncnn-decode.py | 8 +++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py index b21fe5c7e..e4104a5bb 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -131,6 +131,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -144,6 +146,7 @@ class Model: decoder_model = args.decoder_bin_filename decoder_net = ncnn.Net() + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -154,6 +157,8 @@ class Model: joiner_param = args.joiner_param_filename joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -176,7 +181,6 @@ class Model: - next_states, a list of tensors containing the next states """ with self.encoder_net.create_extractor() as ex: - ex.set_num_threads(4) ex.input("in0", ncnn.Mat(x.numpy()).clone()) # layer0 in2-in5 @@ -220,7 +224,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - ex.set_num_threads(4) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -229,7 +232,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - ex.set_num_threads(4) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = 
ex.extract("out0") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3b471fa85..3bd1b0a09 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -104,6 +104,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -118,6 +120,7 @@ class Model: decoder_net = ncnn.Net() decoder_net.opt.use_packing_layout = False + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -129,6 +132,8 @@ class Model: joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() joiner_net.opt.use_packing_layout = False + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -136,7 +141,6 @@ class Model: def run_encoder(self, x, states): with self.encoder_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(x.numpy()).clone()) x_lens = torch.tensor([x.size(0)], dtype=torch.float32) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) @@ -165,7 +169,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -174,7 +177,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index baff15ea6..02ed16a8c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -92,6 +92,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -106,6 +108,7 @@ class Model: decoder_net = ncnn.Net() decoder_net.opt.use_packing_layout = False + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -117,6 +120,8 @@ class Model: joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() joiner_net.opt.use_packing_layout = False + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -124,7 +129,6 @@ class Model: def run_encoder(self, x, states): with self.encoder_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(x.numpy()).clone()) x_lens = torch.tensor([x.size(0)], dtype=torch.float32) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) @@ -153,7 +157,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -162,7 +165,6 @@ class 
Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") From 8582b6e41acbd1258633492212c06589d1370960 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Jan 2023 15:34:30 +0800 Subject: [PATCH 105/120] Add doc about converting conv-emformer to sherpa-ncnn (#830) --- docs/source/conf.py | 6 + ...nv-emformer-transducer-for-ncnn-output.txt | 21 + ...-decode-conv-emformer-transducer-libri.txt | 7 + docs/source/model-export/export-ncnn.rst | 492 +++++++++++++++++- .../lstm_pruned_stateless_transducer.rst | 9 +- 5 files changed, 526 insertions(+), 9 deletions(-) create mode 100644 docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt create mode 100644 docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt diff --git a/docs/source/conf.py b/docs/source/conf.py index 221d9d734..33429f74c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -78,3 +78,9 @@ html_context = { } todo_include_todos = True + +rst_epilog = """ +.. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _git-lfs: https://git-lfs.com/ +.. _ncnn: https://github.com/tencent/ncnn +""" diff --git a/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..ecbdd4b31 --- /dev/null +++ b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt @@ -0,0 +1,21 @@ +2023-01-11 12:15:38,677 INFO [export-for-ncnn.py:220] device: cpu +2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:229] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_v +alid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampl +ing_factor': 4, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.2', 'k2-build-type': +'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'a34171ed85605b0926eebbd0463d059431f4f74a', 'k2-git-date': 'Wed Dec 14 00:06:38 2022', + 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-vers +ion': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'fix-stateless3-train-2022-12-27', 'icefall-git-sha1': '530e8a1-dirty', ' +icefall-git-date': 'Tue Dec 27 13:59:18 2022', 'icefall-path': '/star-fj/fangjun/open-source/icefall', 'k2-path': '/star-fj/fangjun/op +en-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279 +-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '127.0.0.1'}, 'epoch': 30, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefa +ll-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp'), 'bpe_model': './icefall-asr-librispeech-conv-emformer-transdu +cer-stateless2-2022-07-05//data/lang_bpe_500/bpe.model', 'jit': False, 'context_size': 2, 'use_averaged_model': False, 'encoder_dim': +512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'cnn_module_kernel': 31, 'left_context_length': 32, 'chunk_length' +: 32, 'right_context_length': 8, 'memory_size': 32, 'blank_id': 0, 'vocab_size': 500} +2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:231] About 
to create model +2023-01-11 12:15:40,053 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-conv-emformer-transducer-stateless2-2 +022-07-05/exp/epoch-30.pt +2023-01-11 12:15:40,708 INFO [export-for-ncnn.py:315] Number of model parameters: 75490012 +2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:318] Using torch.jit.trace() +2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:320] Exporting encoder +2023-01-11 12:15:41,682 INFO [export-for-ncnn.py:149] chunk_length: 32, right_context_length: 8 diff --git a/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt new file mode 100644 index 000000000..114fe7342 --- /dev/null +++ b/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-01-11 14:02:12,216 INFO [streaming-ncnn-decode.py:320] {'tokens': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav'} +T 51 32 +2023-01-11 14:02:13,141 INFO [streaming-ncnn-decode.py:328] Constructing Fbank computer +2023-01-11 14:02:13,151 INFO [streaming-ncnn-decode.py:331] Reading sound files: ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav +2023-01-11 14:02:13,176 INFO [streaming-ncnn-decode.py:336] torch.Size([106000]) +2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:380] ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav +2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:381] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 3dbb8b514..11471d611 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,12 +1,492 @@ Export to ncnn ============== -We support exporting LSTM transducer models to `ncnn `_. - -Please refer to :ref:`export-model-for-ncnn` for details. +We support exporting both +`LSTM transducer models `_ +and +`ConvEmformer transducer models `_ +to `ncnn `_. We also provide ``_ performing speech recognition using ``ncnn`` with exported models. -It has been tested on Linux, macOS, Windows, and Raspberry Pi. The project is -self-contained and can be statically linked to produce a binary containing -everything needed. 
+It has been tested on Linux, macOS, Windows, ``Android``, and ``Raspberry Pi``. + +`sherpa-ncnn`_ is self-contained and can be statically linked to produce +a binary containing everything needed. Please refer +to its documentation for details: + + - ``_ + + +Export LSTM transducer models +----------------------------- + +Please refer to :ref:`export-lstm-transducer-model-for-ncnn` for details. + + + +Export ConvEmformer transducer models +------------------------------------- + +We use the pre-trained model from the following repository as an example: + + - ``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.10``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + You can also refer to ``_ to download the pre-trained model. + + You have to install `git-lfs`_ before you continue. + +.. code-block:: bash + + cd egs/librispeech/ASR + + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + + git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We download ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + + +In the above code, we download the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``. + +2. Install ncnn and pnnx +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + # We put ncnn into $HOME/open-source/ncnn + # You can change it to anywhere you like + + cd $HOME + mkdir -p open-source + cd open-source + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + + # Note: We don't use "python setup.py install" or "pip install ." here + + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=ON \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is $HOME/open-source/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH + + # Now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 + + ./src/pnnx + +Congratulations! You have successfully installed the following components: + + - ``pnxx``, which is an executable located in + ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use + it to convert models exported by ``torch.jit.trace()``. + - ``ncnn2int8``, which is an executable located in + ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use + it to quantize our models to ``int8``. + - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located + in ``$HOME/open-source/ncnn/python/ncnn``. + + .. note:: + + I am using ``Python 3.8``, so it + is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different + version, say, ``Python 3.9``, the name would be + ``ncnn.cpython-39-x86_64-linux-gnu.so``. + + Also, if you are not using Linux, the file name would also be different. + But that does not matter. As long as you can compile it, it should work. 
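+
+Before you continue, you may want to run a quick sanity check that the build
+above succeeded. The snippet below is not part of the export pipeline; it only
+verifies that the ``ncnn`` Python module can be imported and that ``pnnx`` and
+``ncnn2int8`` are visible in your ``PATH``. It assumes you have already run the
+``export PYTHONPATH=...`` and ``export PATH=...`` commands shown above.
+
+.. code-block:: python
+
+    # Optional sanity check of the ncnn/pnnx build above.
+    # It assumes PYTHONPATH and PATH have been exported as shown earlier.
+    import shutil
+
+    import ncnn  # provided by the modified ncnn repository
+
+    net = ncnn.Net()  # constructing a Net proves the Python bindings load
+    print("ncnn Python bindings: OK")
+
+    for tool in ("pnnx", "ncnn2int8"):
+        location = shutil.which(tool)
+        print(tool, "->", location if location is not None else "NOT found in PATH")
+
+If either tool is reported as missing, please re-check the ``export PATH``
+commands above before proceeding.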
+ +We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your +Python code. We have also set up ``PATH`` so that you can use +``pnnx`` and ``ncnn2int8`` later in your terminal. + +.. caution:: + + Please don't use ``_. + We have made some modifications to the offical `ncnn`_. + + We will synchronize ``_ periodically + with the official one. + +3. Export the model via torch.jit.trace() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp + + ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/ + + ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + --encoder-dim 512 + +.. hint:: + + We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt + + The log shows the model has ``75490012`` number of parameters, i.e., ``~75 M``. + + .. code-block:: + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + You can see that the file size of the pre-trained model is ``289 MB``, which + is roughly ``4 x 75 M``. + +After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx* + + -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt + + +.. _conv-emformer-step-3-export-torchscript-model-via-pnnx: + +3. Export torchscript model via pnnx +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. 
code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file size of the models after converting is about one half +of the models before converting: + + - encoder: 283 MB vs 142 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +4. Test the exported models in icefall +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. 
code-block:: bash + + ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + + +5. Modify the exported encoder for sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``1060 1342``, the first number ``1060`` specifies the number of layers + in this file, while ``1342`` specifies the number intermediate outputs of + this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output. ``in0`` is the output name of + this layer. + +We need to add 1 extra line and the result looks like below: + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. + We don't need to change ``1342`` since the newly added layer has no inputs and outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs and output. Must be ``0 0`` + - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. 
+ - ``2=32``, 2 is the key and 32 is the value of the + parameter ``--memory-size`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``3=31``, 3 is the key and 31 is the value of the + parameter ``--cnn-module-kernel`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``4=8``, 4 is the key and 8 is the value of the + parameter ``--left-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``5=32``, 5 is the key and 32 is the value of the + parameter ``--chunk-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``6=8``, 6 is the key and 8 is the value of the + parameter ``--right-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``7=512``, 7 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 1 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--memory-size`` | + +------+-----------------------------+ + | 3 | ``--cnn-module-kernel`` | + +------+-----------------------------+ + | 4 | ``--left-context-length`` | + +------+-----------------------------+ + | 5 | ``--chunk-length`` | + +------+-----------------------------+ + | 6 | ``--right-context-length`` | + +------+-----------------------------+ + | 7 | ``--encoder-dim`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``1060`` to ``1061``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - Android: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. + +6. (Optional) int8 quantization with sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This step is optional. + +In this step, we describe how to quantize our model with ``int8``. + +Change :ref:`conv-emformer-step-3-export-torchscript-model-via-pnnx` to +disable ``fp16`` when using ``pnnx``: + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt fp16=0 + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt fp16=0 + +.. note:: + + We add ``fp16=0`` when exporting the encoder and joiner. 
``ncnn`` does not + support quantizing the decoder model yet. We will update this documentation + once ``ncnn`` supports it. (Maybe in this year, 2023). + +TODO(fangjun): Finish it. + +Have fun with `sherpa-ncnn`_! diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index 22addd1d2..ce8ba1453 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -515,10 +515,10 @@ To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``: Please see ``_ for how to use the exported models in ``sherpa``. -.. _export-model-for-ncnn: +.. _export-lstm-transducer-model-for-ncnn: -Export model for ncnn -~~~~~~~~~~~~~~~~~~~~~ +Export LSTM transducer models for ncnn +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We support exporting pretrained LSTM transducer models to `ncnn `_ using @@ -657,3 +657,6 @@ by visiting the following links: You can find more usages of the pretrained models in ``_ + +Export ConvEmformer transducer models for ncnn +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 142420b3afa7b07c95f733c2e72ee80078364a44 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 11 Jan 2023 16:45:24 +0800 Subject: [PATCH 106/120] Add docs for distillation (#812) * add README to docs * update documents for distillation * upload png files --- .../librispeech/distillation.rst | 220 ++++++++++++++++++ .../images/distillation_codebook.png | Bin 0 -> 57170 bytes .../images/distillation_directory.png | Bin 0 -> 43816 bytes .../Non-streaming-ASR/librispeech/index.rst | 1 + .../ASR/distillation_with_hubert.sh | 6 +- 5 files changed, 224 insertions(+), 3 deletions(-) create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst new file mode 100644 index 000000000..aa379c3f8 --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -0,0 +1,220 @@ +Distillation with HuBERT +======================== + +This totorial shows you how to perform knowledge distillation in ``icefall`` +with the `LibriSpeech `_ dataset. The distillation method +used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). +Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_ +for more details about MVQ-KD. + +.. note:: + + This tutorial is based on recipe + `pruned_transducer_stateless4 `_. + Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes + with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you + encounter any problems, please open an issue here `icefall `_. + +.. note:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. 
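+
+If you are wondering what the "codebook indexes" mentioned throughout this
+tutorial are: MVQ-KD quantizes embeddings extracted from a teacher model (here,
+a fine-tuned HuBERT-XL) into a small set of integer indexes, one per codebook,
+and the student network is trained to predict those indexes. The snippet below
+is only a deliberately simplified, self-contained illustration of that idea;
+the sizes are made up and the real quantizer used by ``icefall`` is trained
+rather than random. Please refer to the paper above for the actual formulation.
+
+.. code-block:: python
+
+    # A toy illustration of multi-codebook vector quantization (MVQ).
+    # NOT the actual quantizer used in icefall; all sizes here are illustrative.
+    import torch
+
+    num_codebooks, codebook_size, embedding_dim = 8, 256, 1280
+    codebooks = torch.randn(num_codebooks, codebook_size, embedding_dim)
+
+    def quantize(teacher_embedding: torch.Tensor) -> torch.Tensor:
+        """Map one teacher embedding of shape (embedding_dim,) to 8 indexes."""
+        # For each codebook, pick the entry closest to the embedding.
+        dists = ((codebooks - teacher_embedding) ** 2).sum(dim=-1)
+        return dists.argmin(dim=-1)  # shape: (num_codebooks,)
+
+    indexes = quantize(torch.randn(embedding_dim))
+    print(indexes)  # these integers are the distillation targets for the student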
+ +Data preparation +---------------- + +We first prepare necessary training data for ``LibriSpeech``. +This is the same as in `Pruned_transducer_statelessX <./pruned_transducer_stateless.rst>`_. + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Codebook index preparation`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0 + $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Codebook index preparation +-------------------------- + +Here, we prepare necessary data for MVQ-KD. This requires the generation +of codebook indexes (please read our `paper `_. +if you are interested in details). In this tutorial, we use the pre-computed +codebook indexes for convenience. The only thing you need to do is to +run ``./distillation_with_hubert.sh``. + +.. note:: + There are 5 stages in total, the first and second stage will be automatically skipped + when choosing to downloaded codebook indexes prepared by `icefall`_. + Of course, you can extract and compute the codebook indexes by yourself. This + will require you downloading a HuBERT-XL model and it can take a while for + the extraction of codebook indexes. + + +As usual, you can control the stages you want to run by specifying the following +two options: + + - ``--stage`` + - ``--stop-stage`` + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0 + $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5 + +Here are a few options in ``./distillation_with_hubert.sh`` +you need to know before you proceed. + +- ``--full_libri`` If True, use full 960h data. Otherwise only ``train-clean-100`` will be used +- ``--use_extracted_codebook`` If True, the first two stages will be skipped and the codebook + indexes uploaded by us will be downloaded. + +Since we are using the pre-computed codebook indexes, we set +``use_extracted_codebook=True``. If you want to do full `LibriSpeech`_ +experiments, please set ``full_libri=True``. + +The following command downloads the pre-computed codebook indexes +and prepares MVQ-augmented training manifests. + +.. 
code-block:: bash + + $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2 + +Please see the +following screenshot for the output of an example execution. + +.. figure:: ./images/distillation_codebook.png + :width: 800 + :alt: Downloading codebook indexes and preparing training manifest. + :align: center + + Downloading codebook indexes and preparing training manifest. + +.. hint:: + + The codebook indexes we prepared for you in this tutorial + are extracted from the 36-th layer of a fine-tuned HuBERT-XL model + with 8 codebooks. If you want to try other configurations, please + set ``use_extracted_codebook=False`` and set ``embedding_layer`` and + ``num_codebooks`` by yourself. + +Now, you should see the following files under the direcory ``./data/vq_fbank_layer36_cb8``. + +.. figure:: ./images/distillation_directory.png + :width: 800 + :alt: MVQ-augmented training manifests + :align: center + + MVQ-augmented training manifests. + +Whola! You are ready to perform knowledge distillation training now! + +Training +-------- + +To perform training, please run stage 3 by executing the following command. + +.. code-block:: bash + + $ ./prepare.sh --stage 3 --stop-stage 3 # run MVQ training + +Here is the code snippet for training: + +.. code-block:: bash + + WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') + + ./pruned_transducer_stateless6/train.py \ + --manifest-dir ./data/vq_fbank_layer36_cb8 \ + --master-port 12359 \ + --full-libri $full_libri \ + --spec-aug-time-warp-factor -1 \ + --max-duration 300 \ + --world-size ${WORLD_SIZE} \ + --num-epochs 30 \ + --exp-dir $exp_dir \ + --enable-distillation True \ + --codebook-loss-scale 0.01 + +There are a few training arguments in the following +training commands that should be paid attention to. + - ``--enable-distillation`` If True, knowledge distillation training is enabled. + - ``--codebook-loss-scale`` The scale of the knowledge distillation loss. + - ``--manifest-dir`` The path to the MVQ-augmented manifest. + + +Decoding +-------- + +After training finished, you can test the performance on using +the following command. + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES=0 + ./pruned_transducer_stateless6/train.py \ + --decoding-method "modified_beam_search" \ + --epoch 30 \ + --avg 10 \ + --max-duration 200 \ + --exp-dir $exp_dir \ + --enable-distillation True + +You should get similar results as `here `_. + +That's all! Feel free to experiment with your own setups and report your results. +If you encounter any problems during training, please open up an issue `here `_. 
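+
+.. note::
+
+   If you are curious what ``--enable-distillation`` and ``--codebook-loss-scale``
+   change conceptually: an auxiliary loss that predicts the pre-computed codebook
+   indexes from an intermediate encoder layer is scaled and added to the transducer
+   loss. The sketch below is schematic only; the names are illustrative and it is
+   not the actual ``pruned_transducer_stateless6`` code.
+
+   .. code-block:: python
+
+      # Schematic only, not the real icefall implementation.
+      def combine_losses(
+          transducer_loss: float,
+          codebook_loss: float,
+          codebook_loss_scale: float = 0.01,
+          enable_distillation: bool = True,
+      ) -> float:
+          if not enable_distillation:
+              return transducer_loss
+          # --codebook-loss-scale weights the auxiliary distillation term.
+          return transducer_loss + codebook_loss_scale * codebook_loss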
+ diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png new file mode 100644 index 0000000000000000000000000000000000000000..1a40d6c6ec3caf9c551b62f6c1ad6c1058823048 GIT binary patch literal 57170 zcmbrm2UJtrw?2x6!x02*5Rj%=K)^&m=_G<8DhPr|krL@D1f;hV6cMG0s1Yeq5fJGh zp(hmSB_N#uK`EiPKp>>O*v|Rg`+xU8?i+890fxWh);pE`j1gAzh04}EpeIQR%0{4Y zAuv&bz<%l?xNyMK;=lv^}?fWL54 z;8~OX8jHuK;&TCy*{PbjfQBa;Ai&B7?RqoOYpq9cROFf{rtJ_FBoqyt(qxv-nj-JL z^=7S~*Iu8e(oi=NZFw(NOse#*hEz;WyfG?4QE2YqyUTtru=ns-*8*$hFZK29s)Ggd zXzLZfM51e>A0++#Mt<*x7(G}bo;}!x&PFO?NoB0vU;h3n7{jR7Sg!ZN%)lnum35HY zV!XasebBTp6BE>sui6=sk%$-oiQ$P<>S`A7>X(S%C4_BG^Z~$9kHZJP2t)8U;&Dgj zPP@%0_p%HFFtDd3cAlb-g86+@xoKHwcr&@fVRAoBu9QHn7AyqnIzeAmra;=_z6_K( z;q1t(-1Fz~mLXTQtc++H&9!^~h-Cg4+S?ofm8%h;3clO#{c%ws6aM!SzKtxAQ(a#+ zBZ7)Jqw&_FXcwyw8@X7V%g1gmEX>_%rA$N`ZR3m2B0k=})3ASk%Nnfgqk)D;e(v^F zcSUd-*4yj+b^z(yxpLL0c%FG@vQysnltv-2?i1pgv9)33wbPysPz>SZUd13TG^41T zWe(H-ZGKPZGP7s(vnKAiwgB9c>E0loVKJc4-{6(M&{D;_f{#;7#faU$_GTEE?E$DQ z(DJTUn^t5Dfv{^1A}krNKzbZi@nM2GY=jR~JQ6A8s`$D@zpG5cLJZXh;t7+G16vUv z zg0IAQy8qs+u@RV~cERW!d`h&~6iI+&7hhs~HP*|6=xPnH{qWj#fUo~R&jo?mHC7P( z8cVe>#)+n4(|1CwbikbNJP3WZSfqWJJfl@MhYl&MkX#!{SFAXqO}_@$4kMp|9Hod0 z|B&LIz|vy;ll?yh*wT(h^STzE;+dIpcp_M&j|!;~5G?b)oyX&=-W<#0hIm5VuN@H< z;N3-VS}NI4Nd{j>Q2h$SA7Pj8)m~MC;32JSLWbh{K|X(*<87tiJ^_c+4jKTbKS37N zMlEo$(^Nj}V^8SO*7;ZyGy*@umm8%;Sv2rgwJE;1G{3ih0wa_ig1pNt8XU%DANj46 z%`TwOy&aiD>M#Ck07#i0QxEP;6Wr*F*fc_aZGwqSCV{)ll8AFhtH7Ccqr;Wid@=D3 z4shTEm9Jk38guPw-^hRqe84{m_^s@X{+D;l^dXhGA%Kq;eZjQ0H2-sXkalNZqv1lO zW+G4r^EGuUj|o)khO2cZ($QzsZAE!!L%c>o#Tx-&$ZO&?n}Ml-rr*2luSSg7TV0ge zie0!s)>a;X3&OixInG$f&v!o6yAj?E*`>Jir&>XE_p{EEM6KydVqD`Ebr79iK+c*3 zWu|G7%?t9)3KiM$K=A7-+YRJIg5THuJvMTGK68y9C^O^Qe1H@~tZgpLqFXmmH$O0Rn<`oS zaHz{>DXRk%**@$sv7c(|Y=O?kKU*1^NbNB0rkLkV0%L-h&_lGcjj_N^JDw-q198qU zoc_W=X-Y+IDS5mY!Mw3`3Q6N%y&%vT)Tp^6M~H>0kkb5Us;)6jMM_XqRt<_=rS|PD z>Qt*iho7lZ>YIGZ zhfe1r!G{egC3++8+kBq-t-{KJWs%e}^`L0~@|y$&nXA^{k|(|g<73>)+uZKB?qyrG zyNG$=EvIa(ma|9P@4WT5RAN#`JL$mmX50ria{|l}N?jv*kZ$fvGTii+&}ymsT?XPAHfbxpHKGYTUe~ zELeUO-b7wg9fklU9KyT=pIvG3JvIKWogXUZV7p+i2d_pRxuv-S1Y^8iV^?deA4>Tc=9V#JJeJ=;9&ZC?+hlr|8Z!>Sf*6TP)Uo5H9%xoeNF$8>Af90 z{XThMkz#F&ijNw50f+~d^ZPMUfwZ!W8;c}3ILUZ8L?S(RIDb)ep?~k{${7m3iWBk_ zY*eeM*y&mVKMQ$x|4`3{k=q(2ZSekNZ)Mdw-6jqD@M|W;+I3fqRkQI}WeQocbF2r_ z<0ElMyEqxzS3H%9Tq4N83JH&~VKh6{o}QVy>+E4kd*3V{4S>p|{u>{x`TKLzoddL~ z(21ep+~n9E$OC>mV(n|b9kIZ#*!`PfD!{KdQk&=(RiM2mmGo1ZMzA)7T0X2y&`Iww zFFRr$?9C02>3Qp;L&4XkSrDVefCc~Rvfe%{aEZBdiN&H=Rhk`fUcZ>>ZE_~;c(J#%FnQv3ZnLt} z#L}sR)Tte^wZ1C59rB{cZcvwxQGkR#HHodLJolYfdGe)*oUG?m=~;iycFil%$lnG7*>=n3E6swtCdvi4 zDy}_&za(1F7xM1Tp_(?K@>QIgpAFL)C3>4 zZVO+N#kZqoCY$t!3CM%m$LLkG52%G=Foc;*+;EjC7B4b~56?TeJD|`saZ*U|0J;0H zfzK5})1vk)`~g40TEX#1Z^MiYctE{~=Aa$5_Z)OSsiSoGq#JI-dLlzy0mD>D3wZ82 zq=vH1f9K~WQvInKKrU!!4iVd+n)4rHtwU6)D$cA0qO=t4BI20XTJ@3AqfsI5Va+QD zf4C7f)h=VAxCW;&cC7fpo#0JEeyXCfm|UvSP+9xqGT0)cDBGYa;xhhD%*|;^i)$ib zJHOp)ePi(z)dV1~FXH@#a)nLoJPiC6-0WAb1P#opshe|<4xgF#E{o?Xy>L%0_SDr^ zkF3G&!1kuvTb3)Y+()#X*NPofdrtHxCj!&*9UW{ZL%)8P#tckeb&q)MJ+d9sU~s9w zW$WU)xN3T6kdl<(Y*C!URiCQMLaC|-uz!NP&U1O_1r9J>QhU# zRkW5RNK?kA55&I&naLhVbL^R#+`RA9@NE_N9`eV3r#jm){$_ za-D@gO*L|#*P4Y}5%1RXkQ+y)<8#HU?u4Ye($pj!i2yIkkyY|b_26U?eJ>C-K6Y4N z1?wUH6|0{kbaYN^Xq@KL@%Z>s7QYXMhMiU~bHtqd{H|awWD!s&r3ytodg<4;%|{%B zL=d7otAUO==I;BLi;~})54#6!JC$FY$OWcDN`;8G6fSQLXuCxppFdj6{0?7AuX}c0 zYc*bWJ5l-D8IKx~CZ)?B>CvN!1ceXRGwLKtyIv_u$zX|(+Fs8kZfn%fcyCv3MS(%H zSEAZHgn6P`npM;&*GZq^B$N&2E>EXpkD>*B;J@a44_v2!@F#qDV{!X2VGM?+7d29F 
z8X19$RmcKtT!w7-zIau;6>S#RQ8;RR4~N~7iI2(KZe=vX^l4BoduW~(WFT`9^wO%} zqt1UQ4pzof=Ji|jqV>W?o%*>EZoq?Zmzs9YhcdM-@I4{lgI^-lmEc$6on(pjzr;6t zeaAQd@1?#d9&qGk+dk@K%#5~3oUq2NruNQpKhg5$$8M2bd!LT_1u$ zm-K2;kbyQ_sWSkz<|ny~i?3hO0c+>|6AH6Zc?@FxQ`gdREv+Pp*PE7v_ID4iWVxgo zRfFcc0bAS-r5JN5AEhrpJv`gG^c$oN-9tJ8@yP%puu{|11ONEf&%*q(2mV{iv@YhT@;#Gw7Z~kpr>=TJZ>HL#x5ID?^l1+^5|_K}$muo7TN`2x*!MZpiQYX$$59 zKSZmpPDE6mcaV=xF?0C+Y+0r!v&Ju6nz?9F z2|zacVfA`_uq{KUT{o)5qoAsxs{Ca;^Uv^khsxVf2k;_O5bkSq_(Bm`t3Z=}-7@Dj zmK#GBl!fka3YGN%__-Z9_ywnjQg#`MK-9siDY^f=6qhUv7!mc;=7L$_Eu{wY4@Mq^ zg_8mi452ArE;%C4(}y#mp+i4V{+4ANL-EXV@wub~=MtU=+-foHSaRqr4slzmXQ=4U z*SqX1Gq+Bg6tA4Wya?P7Cl2ZGxrUo*hlC z+-P@6+iX`*O6()SlD3`aFgb0%m{X@L3Kry(PC1Om2%g;nXPilGZ7~-Pz;sNtxGzJ> zPS?$(>WxCN6jcw)cAm?+N|DH8Ym>Iv^_ zG(xBKNI+kkC&K+r0lu?CVmcU?*Devft`MuMh5E#DHgy@$DRhA%Kd} z=Oa67#Zj8V>aroAsfr$+;cMwTnN@=((xJ=pP#>v-pwsv_W&Ks!(s8q!3XR?ziR#)B z0@?>>1t+#QA|cWyw?o+J=-m~$7N(@fZC7-=adz`u)8@I~mL!(WtZdbh2>`b0Xjdn{ zEcs4PMf<#=O#1V)cY5CI73r1|&?~E3b>4aysz@NSTBkp{T_(-)s<97RN#PpuNK1;a zhoAq*a~I205&klz_ES33z+RoGB0Zf%Kb!&wTzB91ah$M|E;XYg7Gi7)W8jcXn9v~q zrWX5I#?@N^`>dPFKHVY@%_q$Q3@cVAcgW=BEW0)0i3U80dc6~@-1f$aHQ41_UckTX zL*8}3(lU{CpZtK4H~oeaHL^leRNQtm&ttgjjFV}9AjB{!PD9=U(r^M4XaJlA@4Ybb zF`3Raq-f#YUo{vukbkGcjb(7e;$=NR(^u+?^zT!(bRRf<+3I((xgXmTwHtHZrrq(5 zG+r|7I}?IIX>5+R;9*@w^A`s-GYA4^>7-oNK=Q0H%uFANyd?cglxWs$ciVKV%3lC+ zYz;$8X5S%q>)ORL#j`RK&m=f~q|cq27gsA$p$7weX?f+WX!NRigs(P#43>2)d%O_q+n$YAZF1-3migP z*hwochW4Y5BC=@tK9V-RJD~~Y2sa)#TZ1M8^N?D6D@>GB{ff9PBiFu-oWHWYJRWg1e6r z)jz(Y7wXrdV3$(!3IezvxAIBSjp?mFW-VA)*^=erfUC^}UvzSN=D+S;7NXP8SZ}T~ z=B5&OX=f1_VqS`pjYl?okQeAMCLCGjFRY@cgS8Djl|aG`<{ml$q=I9H^n46sJN&AD znWdF|mEXvTXQ09EFT3@;ec6FGm^Y!R6&KSG1I!OC4qf|15gz46lgEkg4PZbkP@o|Z zf-lq2Ntva#CCB0xHDVz6sM)f(0~$6Xg7Mq&!Y8k9)XYA~t1Kme<_&LsFrD$D&Unqc zg(r&YfQo{q+1cSQKw$$@eJ>u6usqD<%l4;5Hs@zwGV3!o^K0TH(_?v*^C(nt*xeu$;0l^ZXFA+r-?=pZjxyWx!J4O z@t8CbYs>8995K2y8+K;y{8-SiBEDN^KZQy^e4&kXiFR+}h&aH3h%A|zM@Dv>l zVk2FO6sys!iQHNXJ2;kVD*0tFE(3h7$S7;~8Q}UGN%Fza?_uU=faO%{*Np~r0(e@z z>yC}BBYGlmcP3>DJk&gE-p|DOKyFoh5$ClTJ1P}7XiJ}DTtnZvW!a4}5>up9owpiJ zN#&0;{ah>`&q$VT7Nk7Yz143y-dttM$lOpWvB~YTtE-=_($LSn0#8=C=h44a#PLO}f`(K0^Wl|5I_sT^Of z%Q&t7sJkKP8kh0x`D3^|sMxg@E;G{xz@?3br@fAtZ2R$7p;9Yb_o}%$Cv-nIPN+jN zhkU9*xgt{NxI8=9uivdl#EjHjez9ar8n)buIc#%3;rFMUkNjKr7+_D?oF4hjgZW}X zB^+jbI(wwpbJ%9d*HCd2w16K+uJ(YzslfSbQu7!97;J1)*U(Vy7~C+A`IuF%mtcSof!WQQfzK z0lExL_78L!D#oU4ueCq5;8ZzXm!o2H(QuDpa~R)F9r(Iu#;^T^WDmT4ADAW6Js9WSIMpdBx1 z3RIX_BC<*umokK@T-3qNPT4-}OW*bnJnuQrd(L~_-}C$~I+P|l0Wsh8o1-bX{L-0> z!$NavOJ;sD>v`!2!fa37`fz?5Tf=zBIU*@IcqtIe!5YGA_Rd=rVMvA*#$G_vQR@5`H5ExCqB{D1d^6S`ScjOYI zfza)2$z(~bX0WzuxN<`@;n#^2iHD9KmY<^>Ta#P5H;h%L_AuY{%WFR*yy&)vm|_lr zA8ILHW6(7^u+`?Bpo`KV2AKj<`5=cHOLb&pk()Ew=WA{Htm&hLqFG#EM> zC|FnaxA-VEYf1qGbUfWH*4Z1Z?xj;lW>B(anH&GuH2HiC?0+NbT2B?ULEpGhY^-VB6i0Wft$_HdtI6{3Rheh-58v@6EMpw{5|+|n)K4#k z@$Psx((fssjRXLL5=qW{E;-$}3)}K`9t43Dr>M4(H4TRR#qCc!KFCI)?T03QNK8q> zxW=)SKB9<`_HVUG12a`E*B=jWciY=Z<&`@rv@I}T7e#3jLy~X+D*fH`2q0gf4;2mT z6KN+JuEG91#Y)>-0t?~Y-W5S87S-ULKb83<@e4-P7;vGqk`PX)k&^TH%?J05N|C!AL@S^Ht8O16-r&A(M7H9*OutzmVZ;g%Gzt0V_NDIPxL;_~5)0aK7s0?zcHQfbrz1&#!dk9%LOLw8bzC)F12S&4Od z0C!0xl77y!fG^q8tV^7;Z`bQDnG6 zdrHbcAW?@ejh?Psw0Y^UEn%a+@UF9j7hsN3yC~6D(igt=x?hZ|yC_l(X;>s10sv0i|CUoS+*Q#5%niYWjd^AGTyTc`wf Vhj(_{xmp|lePGzWTYHZE{1?T2|A_zq literal 0 HcmV?d00001 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst index 3ebb36b25..bf439861a 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst +++ 
b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst @@ -9,3 +9,4 @@ LibriSpeech pruned_transducer_stateless zipformer_mmi zipformer_ctc_blankskip + distillation diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index a38cf590c..6aaa0333b 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -150,7 +150,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then num_codebooks=8 mkdir -p $exp_dir/vq - codebook_dir=$exp_dir/vq/${teacher_model_id}_layer${embedding_layer}_cb${num_codebooks} + codebook_dir=$exp_dir/vq/${teacher_model_id} mkdir -p codebook_dir codebook_download_dir=$exp_dir/download_codebook if [ -d $codebook_download_dir ]; then @@ -180,9 +180,9 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then ./pruned_transducer_stateless6/extract_codebook_index.py \ --full-libri $full_libri \ --exp-dir $exp_dir \ - --embedding-layer 36 \ + --embedding-layer $embedding_layer \ --num-utts 1000 \ - --num-codebooks 8 \ + --num-codebooks $num_codebooks \ --max-duration 100 \ --teacher-model-id $teacher_model_id \ --use-extracted-codebook $use_extracted_codebook From 958dbb3a1d02ecced9ff62624625892afb2206c3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Jan 2023 20:29:36 +0800 Subject: [PATCH 107/120] add doc for int8 quantization with sherpa-ncnn (#832) * add doc for int8 quantization with sherpa-ncnn * typo fixes --- ...te-int-8-scale-table-for-conv-emformer.txt | 104 ++++++ docs/source/model-export/export-ncnn.rst | 307 +++++++++++++++++- 2 files changed, 397 insertions(+), 14 deletions(-) create mode 100644 docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt new file mode 100644 index 000000000..347e7e51a --- /dev/null +++ b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt @@ -0,0 +1,104 @@ +Don't Use GPU. 
has_gpu: 0, config.use_vulkan_compute: 1 +num encoder conv layers: 88 +num joiner conv layers: 3 +num files: 3 +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +----------encoder---------- +conv_87 : max = 15.942385 threshold = 15.938493 scale = 7.968131 +conv_88 : max = 35.442448 threshold = 15.549335 scale = 8.167552 +conv_89 : max = 23.228289 threshold = 8.001738 scale = 15.871552 +linear_90 : max = 3.976146 threshold = 1.101789 scale = 115.267128 +linear_91 : max = 6.962030 threshold = 5.162033 scale = 24.602713 +linear_92 : max = 12.323041 threshold = 3.853959 scale = 32.953129 +linear_94 : max = 6.905416 threshold = 4.648006 scale = 27.323545 +linear_93 : max = 6.905416 threshold = 5.474093 scale = 23.200188 +linear_95 : max = 1.888012 threshold = 1.403563 scale = 90.483986 +linear_96 : max = 6.856741 threshold = 5.398679 scale = 23.524273 +linear_97 : max = 9.635942 threshold = 2.613655 scale = 48.590950 +linear_98 : max = 6.460340 threshold = 5.670146 scale = 22.398010 +linear_99 : max = 9.532276 threshold = 2.585537 scale = 49.119396 +linear_101 : max = 6.585871 threshold = 5.719224 scale = 22.205809 +linear_100 : max = 6.585871 threshold = 5.751382 scale = 22.081648 +linear_102 : max = 1.593344 threshold = 1.450581 scale = 87.551147 +linear_103 : max = 6.592681 threshold = 5.705824 scale = 22.257959 +linear_104 : max = 8.752957 threshold = 1.980955 scale = 64.110489 +linear_105 : max = 6.696240 threshold = 5.877193 scale = 21.608953 +linear_106 : max = 9.059659 threshold = 2.643138 scale = 48.048950 +linear_108 : max = 6.975461 threshold = 4.589567 scale = 27.671457 +linear_107 : max = 6.975461 threshold = 6.190381 scale = 20.515701 +linear_109 : max = 3.710759 threshold = 2.305635 scale = 55.082436 +linear_110 : max = 7.531228 threshold = 5.731162 scale = 22.159557 +linear_111 : max = 10.528083 threshold = 2.259322 scale = 56.211544 +linear_112 : max = 8.148807 threshold = 5.500842 scale = 23.087374 +linear_113 : max = 8.592566 threshold = 1.948851 scale = 65.166611 +linear_115 : max = 8.437109 threshold = 5.608947 scale = 22.642395 +linear_114 : max = 8.437109 threshold = 6.193942 scale = 20.503904 +linear_116 : max = 3.966980 threshold = 3.200896 scale = 39.676392 +linear_117 : max = 9.451303 threshold = 6.061664 scale = 20.951344 +linear_118 : max = 12.077262 threshold = 3.965800 scale = 32.023804 +linear_119 : max = 9.671615 threshold = 4.847613 scale = 26.198460 +linear_120 : max = 8.625638 threshold = 3.131427 scale = 40.556595 +linear_122 : max = 10.274080 threshold = 4.888716 scale = 25.978189 +linear_121 : max = 10.274080 threshold = 5.420480 scale = 23.429659 +linear_123 : max = 4.826197 threshold = 3.599617 scale = 35.281532 +linear_124 : max = 11.396383 threshold = 7.325849 scale = 17.335875 +linear_125 : max = 9.337198 threshold = 3.941410 scale = 32.221970 +linear_126 : max = 9.699965 threshold = 4.842878 scale = 26.224073 +linear_127 : max = 8.775370 threshold = 3.884215 scale = 32.696438 +linear_129 : max = 9.872276 threshold = 4.837319 scale = 26.254213 +linear_128 : max = 9.872276 threshold = 7.180057 scale = 17.687883 +linear_130 : max = 4.150427 threshold = 3.454298 scale = 36.765789 +linear_131 : max = 11.112692 threshold = 7.924847 scale = 16.025545 +linear_132 : max = 11.852893 threshold = 3.116593 scale = 40.749626 +linear_133 : max 
= 11.517084 threshold = 5.024665 scale = 25.275314 +linear_134 : max = 10.683807 threshold = 3.878618 scale = 32.743618 +linear_136 : max = 12.421055 threshold = 6.322729 scale = 20.086264 +linear_135 : max = 12.421055 threshold = 5.309880 scale = 23.917679 +linear_137 : max = 4.827781 threshold = 3.744595 scale = 33.915554 +linear_138 : max = 14.422395 threshold = 7.742882 scale = 16.402161 +linear_139 : max = 8.527538 threshold = 3.866123 scale = 32.849449 +linear_140 : max = 12.128619 threshold = 4.657793 scale = 27.266134 +linear_141 : max = 9.839593 threshold = 3.845993 scale = 33.021378 +linear_143 : max = 12.442304 threshold = 7.099039 scale = 17.889746 +linear_142 : max = 12.442304 threshold = 5.325038 scale = 23.849592 +linear_144 : max = 5.929444 threshold = 5.618206 scale = 22.605080 +linear_145 : max = 13.382126 threshold = 9.321095 scale = 13.625010 +linear_146 : max = 9.894987 threshold = 3.867645 scale = 32.836517 +linear_147 : max = 10.915313 threshold = 4.906028 scale = 25.886522 +linear_148 : max = 9.614287 threshold = 3.908151 scale = 32.496181 +linear_150 : max = 11.724932 threshold = 4.485588 scale = 28.312899 +linear_149 : max = 11.724932 threshold = 5.161146 scale = 24.606939 +linear_151 : max = 7.164453 threshold = 5.847355 scale = 21.719223 +linear_152 : max = 13.086471 threshold = 5.984121 scale = 21.222834 +linear_153 : max = 11.099524 threshold = 3.991601 scale = 31.816805 +linear_154 : max = 10.054585 threshold = 4.489706 scale = 28.286930 +linear_155 : max = 12.389185 threshold = 3.100321 scale = 40.963501 +linear_157 : max = 9.982999 threshold = 5.154796 scale = 24.637253 +linear_156 : max = 9.982999 threshold = 8.537706 scale = 14.875190 +linear_158 : max = 8.420287 threshold = 6.502287 scale = 19.531588 +linear_159 : max = 25.014746 threshold = 9.423280 scale = 13.477261 +linear_160 : max = 45.633553 threshold = 5.715335 scale = 22.220921 +linear_161 : max = 20.371849 threshold = 5.117830 scale = 24.815203 +linear_162 : max = 12.492933 threshold = 3.126283 scale = 40.623318 +linear_164 : max = 20.697504 threshold = 4.825712 scale = 26.317358 +linear_163 : max = 20.697504 threshold = 5.078367 scale = 25.008038 +linear_165 : max = 9.023975 threshold = 6.836278 scale = 18.577358 +linear_166 : max = 34.860619 threshold = 7.259792 scale = 17.493614 +linear_167 : max = 30.380934 threshold = 5.496160 scale = 23.107042 +linear_168 : max = 20.691216 threshold = 4.733317 scale = 26.831076 +linear_169 : max = 9.723948 threshold = 3.952728 scale = 32.129707 +linear_171 : max = 21.034811 threshold = 5.366547 scale = 23.665123 +linear_170 : max = 21.034811 threshold = 5.356277 scale = 23.710501 +linear_172 : max = 10.556884 threshold = 5.729481 scale = 22.166058 +linear_173 : max = 20.033039 threshold = 10.207264 scale = 12.442120 +linear_174 : max = 11.597379 threshold = 2.658676 scale = 47.768131 +----------joiner---------- +linear_2 : max = 19.293503 threshold = 14.305265 scale = 8.877850 +linear_1 : max = 10.812222 threshold = 8.766452 scale = 14.487047 +linear_3 : max = 0.999999 threshold = 0.999755 scale = 127.031174 +ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233... diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 11471d611..ed0264089 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -204,7 +204,7 @@ Next, we use the following code to export our model: .. 
literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt - The log shows the model has ``75490012`` number of parameters, i.e., ``~75 M``. + The log shows the model has ``75490012`` parameters, i.e., ``~75 M``. .. code-block:: @@ -213,7 +213,7 @@ Next, we use the following code to export our model: -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt You can see that the file size of the pre-trained model is ``289 MB``, which - is roughly ``4 x 75 M``. + is roughly ``75490012*4/1024/1024 = 287.97 MB``. After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, we will get the following files: @@ -286,8 +286,8 @@ We compare the file sizes of the models below before and after converting via `` | joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | +----------------------------------+------------+ -You can see that the file size of the models after converting is about one half -of the models before converting: +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: - encoder: 283 MB vs 142 MB - decoder: 1010 KB vs 503 KB @@ -338,6 +338,8 @@ The output is given below: Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! +.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: + 5. Modify the exported encoder for sherpa-ncnn ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -356,14 +358,15 @@ Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param 1. ``7767517``, it is a magic number and should not be changed. 2. ``1060 1342``, the first number ``1060`` specifies the number of layers - in this file, while ``1342`` specifies the number intermediate outputs of - this file + in this file, while ``1342`` specifies the number of intermediate outputs + of this file 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` is the layer name of this layer; ``0`` means this layer has no input; - ``1`` means this layer has one output. ``in0`` is the output name of + ``1`` means this layer has one output; ``in0`` is the output name of this layer. -We need to add 1 extra line and the result looks like below: +We need to add 1 extra line and also increment the number of layers. +The result looks like below: .. code-block:: bash @@ -376,13 +379,13 @@ We need to add 1 extra line and the result looks like below: 1. ``7767517``, it is still the same 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. - We don't need to change ``1342`` since the newly added layer has no inputs and outputs. + We don't need to change ``1342`` since the newly added layer has no inputs or outputs. 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` This line is newly added. Its explanation is given below: - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. - - ``0 0`` means this layer has no inputs and output. Must be ``0 0`` + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` - ``1=12``, 1 is the key and 12 is the value of the parameter ``--num-encoder-layers`` that you provided when running @@ -483,10 +486,286 @@ disable ``fp16`` when using ``pnnx``: .. 
note:: - We add ``fp16=0`` when exporting the encoder and joiner. ``ncnn`` does not + We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not support quantizing the decoder model yet. We will update this documentation - once ``ncnn`` supports it. (Maybe in this year, 2023). + once `ncnn`_ supports it. (Maybe in this year, 2023). -TODO(fangjun): Finish it. +It will generate the following files -Have fun with `sherpa-ncnn`_! +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +Let us compare again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ + +You can see that the file sizes are doubled when we disable ``fp16``. + +.. note:: + + You can again use ``streaming-ncnn-decode.py`` to test the exported models. + +Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` +to modify ``encoder_jit_trace-pnnx.ncnn.param``. + +Change + +.. code-block:: bash + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +to + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +.. caution:: + + Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` + to change the values for ``SherpaMetaData`` if your model uses a different setting. + + +Next, let us compile `sherpa-ncnn`_ since we will quantize our models within +`sherpa-ncnn`_. + +.. code-block:: bash + + # We will download sherpa-ncnn to $HOME/open-source/ + # You can change it to anywhere you like. 
+ cd $HOME + mkdir -p open-source + + cd open-source + git clone https://github.com/k2-fsa/sherpa-ncnn + cd sherpa-ncnn + mkdir build + cd build + cmake .. + make -j 4 + + ./bin/generate-int8-scale-table + + export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH + +The output of the above commands are: + +.. code-block:: bash + + (py38) kuangfangjun:build$ generate-int8-scale-table + Please provide 10 arg. Currently given: 1 + Usage: + generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt + + Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. + +We need to create a file ``wave_filenames.txt``, in which we need to put +some calibration wave files. For testing purpose, we put the ``test_wavs`` +from the pre-trained model repository ``_ + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + cat < wave_filenames.txt + ../test_wavs/1089-134686-0001.wav + ../test_wavs/1221-135766-0001.wav + ../test_wavs/1221-135766-0002.wav + EOF + +Now we can calculate the scales needed for quantization with the calibration data: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + generate-int8-scale-table \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./encoder-scale-table.txt \ + ./joiner-scale-table.txt \ + ./wave_filenames.txt + +The output logs are in the following: + +.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt + +It generates the following two files: + +.. code-block:: bash + + $ ls -lh encoder-scale-table.txt joiner-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt + +.. caution:: + + Definitely, you need more calibration data to compute the scale table. + +Finally, let us use the scale table to quantize our models into ``int8``. + +.. code-block:: bash + + ncnn2int8 + + usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] + +First, we quantize the encoder model: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + ncnn2int8 \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./encoder-scale-table.txt + +Next, we quantize the joiner model: + +.. code-block:: bash + + ncnn2int8 \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.int8.param \ + ./joiner_jit_trace-pnnx.ncnn.int8.bin \ + ./joiner-scale-table.txt + +The above two commands generate the following 4 files: + +.. code-block:: bash + + -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param + -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param + +Congratulations! 
You have successfully quantized your model from ``float32`` to ``int8``. + +.. caution:: + + ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. + + You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` + and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. + + For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can + replace the following invocation: + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + with + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + +The following table compares again the file sizes: + + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | ++----------------------------------------+------------+ + +You can see that the file sizes of the model after ``int8`` quantization +are much smaller. + +.. hint:: + + Currently, only linear layers and convolutional layers are quantized + with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. + +.. note:: + + You need to test the recognition accuracy after ``int8`` quantization. + +You can find the speed comparison at ``_. + + +That's it! Have fun with `sherpa-ncnn`_! 
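If you want to double-check the size arithmetic quoted in this walkthrough, here is a small standalone Python sketch. It is not part of the patch: the parameter count is the figure reported by the export log, while the bytes-per-parameter values are assumptions about how fp32/fp16/int8 weights are stored on disk.

.. code-block:: python

    # Back-of-the-envelope model sizes for the conv-emformer checkpoint.
    # 75490012 is the parameter count reported by export-for-ncnn.py;
    # 4 / 2 / 1 bytes per parameter correspond to fp32 / fp16 / int8 storage.
    num_params = 75_490_012

    def size_mb(bytes_per_param: float) -> float:
        """Approximate model size in MB for the given storage width."""
        return num_params * bytes_per_param / 1024 / 1024

    print(f"fp32: {size_mb(4):.2f} MB")  # ~287.97 MB, matching the 289 MB pretrained.pt
    print(f"fp16: {size_mb(2):.2f} MB")  # ~144 MB, close to the fp16 encoder+decoder+joiner total
    print(f"int8: {size_mb(1):.2f} MB")  # lower bound only; ncnn quantizes just linear/conv
                                         # layers, so the real int8 encoder is 99 MB, not ~72 MB

The same arithmetic explains why disabling ``fp16`` doubles the ``.ncnn.bin`` file sizes in the tables above, and why ``int8`` quantization does not give an exact ``4x`` reduction.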
From 5c8e9628cc39b9fb1e471d53df9aec06b2602b97 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 13 Jan 2023 15:21:29 +0800 Subject: [PATCH 108/120] update faq for libpython3.10.so not found (#838) --- docs/source/conf.py | 3 + docs/source/faqs.rst | 40 ++++++++++++ .../librispeech/distillation.rst | 65 ++++++++++--------- .../pruned_transducer_stateless.rst | 2 + 4 files changed, 79 insertions(+), 31 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 33429f74c..ef9fe1445 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -81,6 +81,9 @@ todo_include_todos = True rst_epilog = """ .. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _icefall: https://github.com/k2-fsa/icefall .. _git-lfs: https://git-lfs.com/ .. _ncnn: https://github.com/tencent/ncnn +.. _LibriSpeech: https://www.openslr.org/12 +.. _musan: http://www.openslr.org/17/ """ diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst index c70ded431..72b0302d7 100644 --- a/docs/source/faqs.rst +++ b/docs/source/faqs.rst @@ -65,3 +65,43 @@ The fix is: pip uninstall setuptools pip install setuptools==58.0.4 + +ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory +-------------------------------------------------------------------------------------------- + +If you are using ``conda`` and encounter the following issue: + +.. code-block:: + + Traceback (most recent call last): + File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 24, in + from _k2 import DeterminizeWeightPushingType + ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory + + During handling of the above exception, another exception occurred: + + Traceback (most recent call last): + File "/k2-dev/yangyifan/icefall/egs/librispeech/ASR/./pruned_transducer_stateless7_ctc_bs/decode.py", line 104, in + import k2 + File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 30, in + raise ImportError( + ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory + Note: If you're using anaconda and importing k2 on MacOS, + you can probably fix this by setting the environment variable: + export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages:$DYLD_LIBRARY_PATH + +Please first try to find where ``libpython3.10.so.1.0`` locates. + +For instance, + +.. code-block:: bash + + cd $CONDA_PREFIX/lib + find . -name "libpython*" + +If you are able to find it inside ``$CODNA_PREFIX/lib``, please set the +following environment variable: + +.. code-block:: bash + + export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst index aa379c3f8..ea9f350cd 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -1,16 +1,16 @@ Distillation with HuBERT ======================== -This totorial shows you how to perform knowledge distillation in ``icefall`` -with the `LibriSpeech `_ dataset. The distillation method -used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). 
+This tutorial shows you how to perform knowledge distillation in `icefall`_ +with the `LibriSpeech`_ dataset. The distillation method +used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_ for more details about MVQ-KD. .. note:: This tutorial is based on recipe - `pruned_transducer_stateless4 `_. + `pruned_transducer_stateless4 `_. Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you encounter any problems, please open an issue here `icefall `_. @@ -18,7 +18,7 @@ for more details about MVQ-KD. .. note:: We assume you have read the page :ref:`install icefall` and have setup - the environment for ``icefall``. + the environment for `icefall`_. .. HINT:: @@ -27,13 +27,13 @@ for more details about MVQ-KD. Data preparation ---------------- -We first prepare necessary training data for ``LibriSpeech``. -This is the same as in `Pruned_transducer_statelessX <./pruned_transducer_stateless.rst>`_. +We first prepare necessary training data for `LibriSpeech`_. +This is the same as in :ref:`non_streaming_librispeech_pruned_transducer_stateless`. .. hint:: The data preparation is the same as other recipes on LibriSpeech dataset, - if you have finished this step, you can skip to ``Codebook index preparation`` directly. + if you have finished this step, you can skip to :ref:`codebook_index_preparation` directly. .. code-block:: bash @@ -61,8 +61,8 @@ For example, .. HINT:: - If you have pre-downloaded the `LibriSpeech `_ - dataset and the `musan `_ dataset, say, + If you have pre-downloaded the `LibriSpeech`_ + dataset and the `musan`_ dataset, say, they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that ``./prepare.sh`` won't re-download them. @@ -84,24 +84,27 @@ We provide the following YouTube video showing how to run ``./prepare.sh``. .. youtube:: ofEIoJL-mGM +.. _codebook_index_preparation: + Codebook index preparation -------------------------- Here, we prepare necessary data for MVQ-KD. This requires the generation of codebook indexes (please read our `paper `_. -if you are interested in details). In this tutorial, we use the pre-computed -codebook indexes for convenience. The only thing you need to do is to -run ``./distillation_with_hubert.sh``. +if you are interested in details). In this tutorial, we use the pre-computed +codebook indexes for convenience. The only thing you need to do is to +run `./distillation_with_hubert.sh `_. .. note:: - There are 5 stages in total, the first and second stage will be automatically skipped - when choosing to downloaded codebook indexes prepared by `icefall`_. - Of course, you can extract and compute the codebook indexes by yourself. This - will require you downloading a HuBERT-XL model and it can take a while for - the extraction of codebook indexes. - -As usual, you can control the stages you want to run by specifying the following + There are 5 stages in total, the first and second stage will be automatically skipped + when choosing to downloaded codebook indexes prepared by `icefall`_. + Of course, you can extract and compute the codebook indexes by yourself. This + will require you downloading a HuBERT-XL model and it can take a while for + the extraction of codebook indexes. 
+ + +As usual, you can control the stages you want to run by specifying the following two options: - ``--stage`` @@ -115,7 +118,7 @@ For example, $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0 $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5 -Here are a few options in ``./distillation_with_hubert.sh`` +Here are a few options in `./distillation_with_hubert.sh `_ you need to know before you proceed. - ``--full_libri`` If True, use full 960h data. Otherwise only ``train-clean-100`` will be used @@ -126,14 +129,14 @@ Since we are using the pre-computed codebook indexes, we set ``use_extracted_codebook=True``. If you want to do full `LibriSpeech`_ experiments, please set ``full_libri=True``. -The following command downloads the pre-computed codebook indexes -and prepares MVQ-augmented training manifests. +The following command downloads the pre-computed codebook indexes +and prepares MVQ-augmented training manifests. .. code-block:: bash $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2 -Please see the +Please see the following screenshot for the output of an example execution. .. figure:: ./images/distillation_codebook.png @@ -146,12 +149,12 @@ following screenshot for the output of an example execution. .. hint:: The codebook indexes we prepared for you in this tutorial - are extracted from the 36-th layer of a fine-tuned HuBERT-XL model + are extracted from the 36-th layer of a fine-tuned HuBERT-XL model with 8 codebooks. If you want to try other configurations, please - set ``use_extracted_codebook=False`` and set ``embedding_layer`` and + set ``use_extracted_codebook=False`` and set ``embedding_layer`` and ``num_codebooks`` by yourself. -Now, you should see the following files under the direcory ``./data/vq_fbank_layer36_cb8``. +Now, you should see the following files under the directory ``./data/vq_fbank_layer36_cb8``. .. figure:: ./images/distillation_directory.png :width: 800 @@ -165,7 +168,7 @@ Whola! You are ready to perform knowledge distillation training now! Training -------- -To perform training, please run stage 3 by executing the following command. +To perform training, please run stage 3 by executing the following command. .. code-block:: bash @@ -176,7 +179,7 @@ Here is the code snippet for training: .. code-block:: bash WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') - + ./pruned_transducer_stateless6/train.py \ --manifest-dir ./data/vq_fbank_layer36_cb8 \ --master-port 12359 \ @@ -191,6 +194,7 @@ Here is the code snippet for training: There are a few training arguments in the following training commands that should be paid attention to. + - ``--enable-distillation`` If True, knowledge distillation training is enabled. - ``--codebook-loss-scale`` The scale of the knowledge distillation loss. - ``--manifest-dir`` The path to the MVQ-augmented manifest. @@ -204,7 +208,7 @@ the following command. .. code-block:: bash - export CUDA_VISIBLE_DEVICES=0 + export CUDA_VISIBLE_DEVICES=0 ./pruned_transducer_stateless6/train.py \ --decoding-method "modified_beam_search" \ --epoch 30 \ @@ -217,4 +221,3 @@ You should get similar results as `here `_. 
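To make the MVQ-KD objective more concrete, the sketch below shows the general shape of a codebook-prediction loss. It is purely illustrative: the classifier sizes, tensor shapes, and function names are assumptions made for this example, and the actual recipe computes the codebook loss through a dedicated quantization library rather than this code. The point is only that predicting each of the 8 codebook indexes with a per-codebook cross-entropy yields the distillation loss that ``--codebook-loss-scale`` weights against the transducer loss.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sizes: 8 codebooks come from the recipe; codebook_size and
    # student_dim are made-up values for this sketch.
    num_codebooks, codebook_size, student_dim = 8, 256, 384

    # One small classifier per codebook predicts that codebook's index from
    # the student encoder output.
    predictors = nn.ModuleList(
        [nn.Linear(student_dim, codebook_size) for _ in range(num_codebooks)]
    )

    def codebook_loss(student_out: torch.Tensor, indexes: torch.Tensor) -> torch.Tensor:
        """student_out: (N, T, student_dim); indexes: (N, T, num_codebooks), int64."""
        loss = student_out.new_zeros(())
        for i, proj in enumerate(predictors):
            logits = proj(student_out).reshape(-1, codebook_size)  # (N*T, codebook_size)
            target = indexes[..., i].reshape(-1)                   # (N*T,)
            loss = loss + F.cross_entropy(logits, target)
        return loss

    # Toy usage: random student outputs and random teacher codebook indexes.
    x = torch.randn(2, 50, student_dim)
    idx = torch.randint(0, codebook_size, (2, 50, num_codebooks))
    print(codebook_loss(x, idx))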
- diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index 86d43c8fe..42fd3df77 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -1,3 +1,5 @@ +.. _non_streaming_librispeech_pruned_transducer_stateless: + Pruned transducer statelessX ============================ From 2a463a420d5080a93ac8933554e13f788a8a59e1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 16 Jan 2023 20:15:35 +0800 Subject: [PATCH 109/120] Filter uneven-sized batch (#843) * add filter_uneven_sized_batch fucntion * set --filter-uneven-sized-batch=True as default --- .../ASR/pruned_transducer_stateless7/train.py | 33 ++++++++++++++++- icefall/utils.py | 36 +++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 31a3a0505..a806244ff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -82,7 +82,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -368,6 +374,21 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--filter-uneven-sized-batch", + type=str2bool, + default=True, + help="""Whether to filter uneven-sized minibatch. + For the uneven-sized batch, the total duration after padding would possibly + cause OOM. Hence, for each batch, which is sorted descendingly by length, + we simply drop the last few shortest samples, so that the retained total frames + (after padding) would not exceed `allowed_max_frames`: + `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + where `max_frames = max_duration * 1000 // frame_shift_ms`. + We set allowed_excess_duration_ratio=0.1. + """, + ) + add_model_arguments(parser) return parser @@ -420,6 +441,9 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, + # only used when params.filter_uneven_sized_batch is True + "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -642,6 +666,13 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" + if params.filter_uneven_sized_batch: + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int( + max_frames * (1.0 + params.allowed_excess_duration_ratio) + ) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) diff --git a/icefall/utils.py b/icefall/utils.py index 99e51a2a9..ba0b7fe43 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool: import importlib return all(importlib.util.find_spec(m) is not None for m in modules) + + +def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int): + """For the uneven-sized batch, the total duration after padding would possibly + cause OOM. Hence, for each batch, which is sorted descendingly by length, + we simply drop the last few shortest samples, so that the retained total frames + (after padding) would not exceed the given allow_max_frames. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + allowed_max_frames: + The allowed max number of frames in batch. + """ + features = batch["inputs"] + supervisions = batch["supervisions"] + + N, T, _ = features.size() + assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max()) + keep_num_utt = allowed_max_frames // T + + if keep_num_utt >= N: + return batch + + # Note: we assume the samples in batch is sorted descendingly by length + logging.info( + f"Filtering uneven-sized batch, original batch size is {N}, " + f"retained batch size is {keep_num_utt}." + ) + batch["inputs"] = features[:keep_num_utt] + for k, v in supervisions.items(): + assert len(v) == N, (len(v), N) + batch["supervisions"][k] = v[:keep_num_utt] + + return batch From 0af3e7beda1cb47cba8b51ce71f691e86cae2091 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 Jan 2023 20:26:36 +0800 Subject: [PATCH 110/120] fix export for stateless4 (#844) --- egs/librispeech/ASR/pruned_transducer_stateless4/export.py | 2 ++ egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py | 1 + .../ASR/pruned_transducer_stateless4/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 401b3ef3a..8f33f5b05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -50,6 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -261,6 +262,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. 
# Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py new file mode 120000 index 000000000..9aa06f82f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From f5ff7a18ebf90c82dd73434b276328fcbe287c13 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 17 Jan 2023 11:28:59 +0800 Subject: [PATCH 111/120] Fix the unclear description for streaming model (#849) --- docs/source/recipes/Streaming-ASR/introduction.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst index d81156659..e1382e77d 100644 --- a/docs/source/recipes/Streaming-ASR/introduction.rst +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -30,8 +30,9 @@ In icefall, we implement the streaming conformer the way just like what `WeNet < See :doc:`Pruned transducer statelessX ` for more details. .. HINT:: - If you want to adapt a non-streaming conformer model to be streaming, please refer - to `this pull request `_. + If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer + to `this pull request `_. After adding the code needed by streaming training, + you have to re-train it with the extra arguments metioned in the docs above to get a streaming model. Streaming Emformer From 6b1ab71dc9c715fe08f5ba7dadc6d7c083be904c Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 27 Jan 2023 21:24:12 +0800 Subject: [PATCH 112/120] hardcode --filter-uneven-sized-batch (#854) --- .../ASR/pruned_transducer_stateless7/train.py | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index a806244ff..6022406eb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -374,21 +374,6 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--filter-uneven-sized-batch", - type=str2bool, - default=True, - help="""Whether to filter uneven-sized minibatch. - For the uneven-sized batch, the total duration after padding would possibly - cause OOM. Hence, for each batch, which is sorted descendingly by length, - we simply drop the last few shortest samples, so that the retained total frames - (after padding) would not exceed `allowed_max_frames`: - `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - where `max_frames = max_duration * 1000 // frame_shift_ms`. - We set allowed_excess_duration_ratio=0.1. 
- """, - ) - add_model_arguments(parser) return parser @@ -442,7 +427,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10.0, - # only used when params.filter_uneven_sized_batch is True "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), @@ -666,12 +650,16 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - if params.filter_uneven_sized_batch: - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int( - max_frames * (1.0 + params.allowed_excess_duration_ratio) - ) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] From 1ce2bc1ee08a9b31b00c12aeb0912f41ec399d3f Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Sat, 28 Jan 2023 14:47:21 +0900 Subject: [PATCH 113/120] edit comments (#852) --- .../pruned_transducer_stateless7/zipformer.py | 16 +++++++-------- .../zipformer.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index b1717ec64..5cde57812 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -197,13 +197,13 @@ class Zipformer(EncoderInterface): """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. + mask values repeated self.zipformer_downsampling_factors times. Args: x: the embeddings (needed for the shape and dtype and device), of shape @@ -1009,10 +1009,10 @@ class RelPositionMultiheadAttention(nn.Module): # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. 
in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + 2 * attention_dim # query, key + + attention_dim // 2 # value + + pos_dim * num_heads # positional encoding query + ) self.in_proj = ScaledLinear( embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 @@ -1509,7 +1509,7 @@ class FeedforwardModule(nn.Module): class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Args: channels (int): The number of channels of conv layers. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 88beb38c1..e13629384 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -421,13 +421,13 @@ class Zipformer(EncoderInterface): """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. + mask values repeated self.zipformer_downsampling_factors times. Args: x: the embeddings (needed for the shape and dtype and device), of shape @@ -1687,8 +1687,8 @@ class RelPositionalEncoding(torch.nn.Module): if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return - # Suppose `i` means to the position of query vecotr and `j` means the - # position of key vector. We use position relative positions when keys + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use positive relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Date: Sat, 28 Jan 2023 14:43:47 +0800 Subject: [PATCH 114/120] fix expired links (#856) --- egs/aishell/ASR/README.md | 2 +- egs/librispeech/ASR/README.md | 2 +- egs/timit/ASR/README.md | 2 +- egs/yesno/ASR/README.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 75fc6326e..f4a59e552 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,7 +1,7 @@ # Introduction -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 94cb445a8..9ffd78d5b 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -1,6 +1,6 @@ # Introduction -Please refer to for how to run models in this recipe. +Please refer to for how to run models in this recipe. [./RESULTS.md](./RESULTS.md) contains the latest results. 
diff --git a/egs/timit/ASR/README.md b/egs/timit/ASR/README.md index f10bfccfd..d493fc479 100644 --- a/egs/timit/ASR/README.md +++ b/egs/timit/ASR/README.md @@ -1,3 +1,3 @@ -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md index 7257bad9a..38b491fc6 100644 --- a/egs/yesno/ASR/README.md +++ b/egs/yesno/ASR/README.md @@ -10,5 +10,5 @@ get the following WER: ``` Please refer to - + for detailed instructions. From e277e31e37279da78ece356efc664e310ef18e5d Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 29 Jan 2023 15:35:36 +0800 Subject: [PATCH 115/120] update huggingface link of zipformer_ctc_blankskip.rst (#858) * update huggingface link * update link --------- Co-authored-by: yifanyang --- .../Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index 56a420605..4929df950 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -447,7 +447,8 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following links: - - ``_ + - trained on LibriSpeech 100h: ``_ + - trained on LibriSpeech 960h: ``_ See ``_ for the details of the above pretrained models From e9019511eb1792b6fa2c166dbe4f6ab02e7e537f Mon Sep 17 00:00:00 2001 From: BuaaAlban Date: Tue, 31 Jan 2023 15:19:50 +0800 Subject: [PATCH 116/120] Fix bug in streaming_conformer_ctc egs (#862) * Update train.py Fix transducer lstm egs bug as mentioned in issue 579 * Update train.py fix dataloader bug --- .../ASR/streaming_conformer_ctc/train.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index 553b7d092..d265de45b 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -50,7 +50,7 @@ from icefall.utils import ( setup_logger, str2bool, ) - +from lhotse.cut import Cut def get_parser(): parser = argparse.ArgumentParser( @@ -645,8 +645,23 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) scan_pessimistic_batches_for_oom( model=model, From d8234e199c65a5971827ddaaa4deb72bd173f0ae Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 31 Jan 2023 15:57:03 +0800 Subject: [PATCH 117/120] Add export to ONNX for Zipformer+CTC using blank skip 
(#861) * Add export to ONNX for Zipformer+CTC using blank skip --------- Co-authored-by: yifanyang --- .../export.py | 6 +- .../export_onnx.py | 665 ++++++++++++++++++ .../frame_reducer.py | 76 +- .../onnx_pretrained.py | 461 ++++++++++++ 4 files changed, 1188 insertions(+), 20 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 96d316604..05df8cfff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -72,14 +72,14 @@ Check ./pretrained.py for its usage. Note: If you don't want to train a model from scratch, we have provided one for you. You can get it at -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 with the following commands: sudo apt-get install git-lfs git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py new file mode 100644 index 000000000..50efa6e60 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -0,0 +1,665 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to ONNX format + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. 
+ + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=True, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 2000, 80, dtype=torch.float32) + x_lens = torch.tensor([2000] * 15, dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_lconv_onnx( + lconv: nn.Module, + lconv_filename: str, + opset_version: int = 11, +) -> None: + """Export the lconv to ONNX format. + + The exported lconv has two inputs: + + - lconv_input: a tensor of shape (N, T, C) + - src_key_padding_mask: a tensor of shape (N, T) + + and has one output: + + - lconv_out: a tensor of shape (N, T, C) + + Args: + lconv: + The lconv to be exported. + lconv_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32) + src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool) + + torch.onnx.export( + lconv, + (lconv_input, src_key_padding_mask), + lconv_filename, + verbose=False, + opset_version=opset_version, + input_names=["lconv_input", "src_key_padding_mask"], + output_names=["lconv_out"], + dynamic_axes={ + "lconv_input": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "lconv_out": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {lconv_filename}") + + +def export_frame_reducer_onnx( + frame_reducer: nn.Module, + frame_reducer_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has four inputs: + + - x: a tensor of shape (N, T, C) + - x_lens: a tensor of shape (N, T) + - ctc_output: a tensor of shape (N, T, vocab_size) + - blank_id: an int, always 0 + + and has two outputs: + + - x_fr: a tensor of shape (N, T, C) + - x_lens_fr: a tensor of shape (N, T) + + Args: + frame_reducer: + The frame_reducer to be exported. + frame_reducer_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.randn(15, 498, 500, dtype=torch.float32) + + torch.onnx.export( + frame_reducer, + (x, x_lens, ctc_output), + frame_reducer_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "ctc_output"], + output_names=["out", "out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "ctc_output": {0: "N", 1: "T"}, + "out": {0: "N", 1: "T"}, + "out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {frame_reducer_filename}") + + +def export_ctc_output_onnx( + ctc_output: nn.Module, + ctc_output_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has one inputs: + + - encoder_out: a tensor of shape (N, T, C) + + and has one output: + + - ctc_output: a tensor of shape (N, T, vocab_size) + + Args: + ctc_output: + The ctc_output to be exported. + ctc_output_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32) + + torch.onnx.export( + ctc_output, + (encoder_out), + ctc_output_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["ctc_output"], + dynamic_axes={ + "encoder_out": {0: "N", 1: "T"}, + "ctc_output": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {ctc_output_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + 
model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + lconv_filename = params.exp_dir / "lconv.onnx" + export_lconv_onnx( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) + + frame_reducer_filename = params.exp_dir / "frame_reducer.onnx" + export_frame_reducer_onnx( + model.frame_reducer, + frame_reducer_filename, + opset_version=opset_version, + ) + + ctc_output_filename = params.exp_dir / "ctc_output.onnx" + export_ctc_output_onnx( + model.ctc_output, + ctc_output_filename, + opset_version=opset_version, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py old mode 100755 new mode 100644 index 9fe88929d..4a19edf66 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F + from icefall.utils import make_pad_mask @@ -43,7 +44,6 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, - blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -54,26 +54,68 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. - blank_id: - The ID of the blank symbol. Returns: - x_fr: + out: The frame reduced encoder output with shape [N, T', C]. - x_lens_fr: + out_lens: A tensor of shape (batch_size,) containing the number of frames in - `x_fr` before padding. + `out` before padding. 
""" + N, T, C = x.size() + padding_mask = make_pad_mask(x_lens) - non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask) - frames_list: List[torch.Tensor] = [] - lens_list: List[int] = [] - for i in range(x.shape[0]): - frames = x[i][non_blank_mask[i]] - frames_list.append(frames) - lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list, batch_first=True) - x_lens_fr = torch.tensor(lens_list).to(device=x.device) + out_lens = non_blank_mask.sum(dim=1) + max_len = out_lens.max() + pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens + max_pad_len = pad_lens_list.max() - return x_fr, x_lens_fr + out = F.pad(x, (0, 0, 0, max_pad_len)) + + valid_pad_mask = ~make_pad_mask(pad_lens_list) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + + out = out[total_valid_mask].reshape(N, -1, C) + + return out.to(device=x.device), out_lens.to(device=x.device) + + +if __name__ == "__main__": + import time + from torch.nn.utils.rnn import pad_sequence + + test_times = 10000 + frame_reducer = FrameReducer() + + # non zero case + x = torch.ones(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32)) + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + + avg_time = 0 + for i in range(test_times): + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) + + # all zero case + x = torch.zeros(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32) + + avg_time = 0 + for i in range(test_times): + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py new file mode 100644 index 000000000..8ff02fbcb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \ + --lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \ + --frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \ + --ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--lconv-filename", + type=str, + required=True, + help="Path to the lconv onnx model. ", + ) + + parser.add_argument( + "--frame-reducer-filename", + type=str, + required=True, + help="Path to the frame reducer onnx model. ", + ) + + parser.add_argument( + "--ctc-output-filename", + type=str, + required=True, + help="Path to the ctc_output onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. 
+ expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = 
[h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + lconv = ort.InferenceSession( + args.lconv_filename, + sess_options=session_opts, + ) + + frame_reducer = ort.InferenceSession( + args.frame_reducer_filename, + sess_options=session_opts, + ) + + ctc_output = ort.InferenceSession( + args.ctc_output_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + ctc_output_input_nodes = ctc_output.get_inputs() + ctc_output_out_nodes = ctc_output.get_outputs() + ctc_out = ctc_output.run( + [ctc_output_out_nodes[0].name], + { + ctc_output_input_nodes[0].name: encoder_out, + }, + )[0] + + lconv_input_nodes = lconv.get_inputs() + lconv_out_nodes = lconv.get_outputs() + encoder_out = lconv.run( + [lconv_out_nodes[0].name], + { + lconv_input_nodes[0].name: encoder_out, + lconv_input_nodes[1] + .name: make_pad_mask(torch.from_numpy(encoder_out_lens)) + .numpy(), + }, + )[0] + + frame_reducer_input_nodes = frame_reducer.get_inputs() + frame_reducer_out_nodes = frame_reducer.get_outputs() + encoder_out_fr, 
encoder_out_lens_fr = frame_reducer.run( + [frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name], + { + frame_reducer_input_nodes[0].name: encoder_out, + frame_reducer_input_nodes[1].name: encoder_out_lens, + frame_reducer_input_nodes[2].name: ctc_out, + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out_fr, + encoder_out_lens=encoder_out_lens_fr, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() From e36ea89112bb3d81602cb4df51bd68e6d06dc150 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 1 Feb 2023 21:04:56 +0800 Subject: [PATCH 118/120] update result.md for pruned_transducer_stateless7_ctc_bs (#865) --- egs/librispeech/ASR/RESULTS.md | 105 ++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index b30cf7c1f..a3e44f09c 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -93,13 +93,13 @@ results at: Number of model parameters: 69136519, i.e., 69.14 M -| | test-clean | test-other | comment | -|--------------------------|------------|-------------|---------------------| -| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | -| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | -| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | -| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | -| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | +| | test-clean | test-other | comment | +| ---------------------- | ---------- | ---------- | ------------------- | +| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | +| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | +| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | +| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | +| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | The training commands are: ```bash @@ -134,6 +134,97 @@ for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; done ``` +### pruned_transducer_stateless7_ctc_bs (zipformer with transducer loss and ctc loss using blank skip) + +See https://github.com/k2-fsa/icefall/pull/730 for more details. + +[pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 76804822, i.e., 76.80 M + +Test on 8-card V100 cluster, with 4-card busy and 4-card idle. 
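+
+To see why blank skip shortens decoding, the sketch below (an illustration added for this write-up, not code from the recipe) shows the frame-selection rule that the recipe's `FrameReducer` applies: a frame is dropped whenever the CTC blank posterior is at least 0.9 (i.e. its blank log-probability is not below `log(0.9)`), so only the surviving frames are passed on to the transducer decoder and joiner. Padding handling is omitted here for brevity; blank id 0 and log-probability CTC output are assumed, matching the recipe.
+
+```python
+import math
+
+import torch
+
+
+def count_kept_frames(ctc_output: torch.Tensor, blank_id: int = 0) -> torch.Tensor:
+    """Return the number of frames that survive blank skip for each utterance.
+
+    ctc_output: (N, T, vocab_size) log-probabilities from the CTC head.
+    A frame is kept only if its blank log-probability is below log(0.9),
+    mirroring the threshold used by FrameReducer in this recipe.
+    """
+    non_blank_mask = ctc_output[:, :, blank_id] < math.log(0.9)
+    return non_blank_mask.sum(dim=1)  # (N,): frames fed to the decoder/joiner
+
+
+# Toy usage with random log-probabilities over a 500-token vocabulary.
+ctc_output = torch.randn(2, 100, 500).log_softmax(dim=-1)
+print(count_kept_frames(ctc_output))
+```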
+ +#### greedy_search + +| model | test-clean | test-other | decoding time(s) | comment | +| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- | +| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.28 | 5.53 | 48.939 | --epoch 30 --avg 13 | +| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.24 | 5.18 | 91.900 | --epoch 30 --avg 8 | + +- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip both on training and decoding, and [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn`t apply blank skip. +- Applying blank skip both on training and decoding is **1.88 times** faster than the model that doesn't apply blank skip without obvious performance loss. + +#### modified_beam_search + +| model | test-clean | test-other | decoding time(s) | comment | +| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- | +| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.26 | 5.44 | 80.446 | --epoch 30 --avg 13 | +| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.20 | 5.12 | 283.676 | --epoch 30 --avg 8 | + +- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip both on training and decoding, and [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn`t apply blank skip. +- Applying blank skip both on training and decoding is **3.53 times** faster than the model that doesn't apply blank skip without obvious performance loss. + +The training commands for the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --ctc-loss-scale 0.2 \ + --master-port 12535 +``` + +The decoding commands for the transducer branch of the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +The decoding commands for the transducer branch of the model without blank skip ([pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` ### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss) From 1e6d6f816001dbcf1275204385740a81bcc2ff14 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: 
Fri, 3 Feb 2023 11:54:57 +0800 Subject: [PATCH 119/120] shuffle full Librispeech for zipformer recipes (#869) * shuffle libri --- egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc/train.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc_bs/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_streaming/train.py | 6 +++--- egs/librispeech/ASR/pruned_transducer_stateless8/train.py | 6 +++--- egs/librispeech/ASR/streaming_conformer_ctc/train.py | 3 ++- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 6022406eb..792a243e5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1043,10 +1043,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 5a05e1836..718381baa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1072,10 +1072,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 522ecc974..b282ab9db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -55,9 +55,9 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decoder import Decoder +from frame_reducer import FrameReducer from joiner import Joiner from lconv import LConv -from frame_reducer import FrameReducer from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed @@ -1063,10 +1063,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 2bdc882a5..c7a2a136d 
100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1049,10 +1049,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index abe249c7b..b0abad5ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,10 +1154,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index d265de45b..bb55ed6bb 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -30,6 +30,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse.cut import Cut from lhotse.utils import fix_random_seed from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP @@ -50,7 +51,7 @@ from icefall.utils import ( setup_logger, str2bool, ) -from lhotse.cut import Cut + def get_parser(): parser = argparse.ArgumentParser( From bffce413f07d938e35d69d5eb2f360c7ff842502 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 3 Feb 2023 12:32:06 +0800 Subject: [PATCH 120/120] Fix filename ctc_guild_decode_bs.py -> ctc_guide_decode_bs.py (#870) * fix filename ctc_guild_decode_bs.py -> ctc_guide_decode_bs.py --------- Co-authored-by: yifanyang --- .../librispeech/zipformer_ctc_blankskip.rst | 14 +++++++------- egs/librispeech/ASR/RESULTS.md | 2 +- ...c_guild_decode_bs.py => ctc_guide_decode_bs.py} | 14 +++++++------- .../pruned_transducer_stateless7_ctc_bs/lconv.py | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) rename egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/{ctc_guild_decode_bs.py => ctc_guide_decode_bs.py} (98%) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index 4929df950..aa73bfe33 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -299,11 +299,11 @@ to run the training part first. - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end of each epoch. You can pass ``--epoch`` to - ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. 
+ ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved every ``--save-every-n`` batches. You can pass ``--iter`` to - ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. We suggest that you try both types of checkpoints and choose the one that produces the lowest WERs. @@ -311,7 +311,7 @@ to run the training part first. .. code-block:: bash $ cd egs/librispeech/ASR - $ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help + $ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help shows the options for decoding. @@ -320,7 +320,7 @@ The following shows the example using ``epoch-*.pt``: .. code-block:: bash for m in greedy_search fast_beam_search modified_beam_search; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 30 \ --avg 13 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ @@ -333,7 +333,7 @@ To test CTC branch, you can use the following command: .. code-block:: bash for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 30 \ --avg 13 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ @@ -367,7 +367,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p .. hint:: - To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``, + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``, you can run: .. code-block:: bash @@ -376,7 +376,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p ln -s pretrained epoch-9999.pt And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to - ``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``. + ``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``. 
To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you can run: diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index a3e44f09c..1a894498e 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -194,7 +194,7 @@ The decoding commands for the transducer branch of the model using blank skip ([ for m in greedy_search modified_beam_search fast_beam_search; do for epoch in 30; do for avg in 15; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch $epoch \ --avg $avg \ --use-averaged-model 1 \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index 9c2166aaf..01ba7b711 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -21,7 +21,7 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -29,7 +29,7 @@ Usage: --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -38,7 +38,7 @@ Usage: --beam-size 4 (3) modified beam search -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -47,7 +47,7 @@ Usage: --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -58,7 +58,7 @@ Usage: --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ @@ -71,7 +71,7 @@ Usage: --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -84,7 +84,7 @@ Usage: --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py index bfd49d533..a902358ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -62,7 +62,7 @@ class LConv(nn.Module): 
kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, - groups=channels, + groups=2 * channels, bias=bias, )
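
The last hunk changes the depthwise convolution in `lconv.py` from `groups=channels` to `groups=2 * channels`. Below is a minimal sketch of what this means for `nn.Conv1d`, assuming (as the new `groups` value requires) that the layer's input carries `2 * channels` channels: with `groups` equal to the number of input channels, every channel is filtered independently, i.e. a true depthwise convolution, whereas the old value mixed pairs of channels within each group. The numbers used here are illustrative only.

import torch
import torch.nn as nn

channels = 4
kernel_size = 3

# groups == in_channels makes the convolution depthwise: each of the
# 2 * channels input channels gets its own filter.
depthwise = nn.Conv1d(
    in_channels=2 * channels,
    out_channels=2 * channels,
    kernel_size=kernel_size,
    padding=(kernel_size - 1) // 2,
    groups=2 * channels,
    bias=True,
)

x = torch.randn(1, 2 * channels, 10)  # (N, C, T)
print(depthwise(x).shape)  # torch.Size([1, 8, 10])

# groups = channels (the old value) would also be legal for this shape,
# but it would mix pairs of channels instead of filtering each channel
# on its own.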