diff --git a/egs/bengaliai_speech/ASR/zipformer/beam_search.py b/egs/bengaliai_speech/ASR/zipformer/beam_search.py
index 27645364d..8e2c0a65c 120000
--- a/egs/bengaliai_speech/ASR/zipformer/beam_search.py
+++ b/egs/bengaliai_speech/ASR/zipformer/beam_search.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/decode_stream.py b/egs/bengaliai_speech/ASR/zipformer/decode_stream.py
index 9229f8c21..b8d8ddfc4 120000
--- a/egs/bengaliai_speech/ASR/zipformer/decode_stream.py
+++ b/egs/bengaliai_speech/ASR/zipformer/decode_stream.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/decoder.py b/egs/bengaliai_speech/ASR/zipformer/decoder.py
index acc59c1d6..5a8018680 120000
--- a/egs/bengaliai_speech/ASR/zipformer/decoder.py
+++ b/egs/bengaliai_speech/ASR/zipformer/decoder.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/encoder_interface.py b/egs/bengaliai_speech/ASR/zipformer/encoder_interface.py
index eca9b8c94..c2eaca671 120000
--- a/egs/bengaliai_speech/ASR/zipformer/encoder_interface.py
+++ b/egs/bengaliai_speech/ASR/zipformer/encoder_interface.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/jit_pretrained.py b/egs/bengaliai_speech/ASR/zipformer/jit_pretrained.py
index 4c3d8c1e0..25108391f 120000
--- a/egs/bengaliai_speech/ASR/zipformer/jit_pretrained.py
+++ b/egs/bengaliai_speech/ASR/zipformer/jit_pretrained.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_ctc.py b/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_ctc.py
index 1bd6f1ded..9a8da5844 120000
--- a/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_ctc.py
+++ b/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_ctc.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_streaming.py b/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_streaming.py
index 987c78292..1962351e9 120000
--- a/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_streaming.py
+++ b/egs/bengaliai_speech/ASR/zipformer/jit_pretrained_streaming.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/joiner.py b/egs/bengaliai_speech/ASR/zipformer/joiner.py
index 3ba6f03be..5b8a36332 120000
--- a/egs/bengaliai_speech/ASR/zipformer/joiner.py
+++ b/egs/bengaliai_speech/ASR/zipformer/joiner.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/onnx_check.py b/egs/bengaliai_speech/ASR/zipformer/onnx_check.py
index 7f833b320..f3dd42004 120000
--- a/egs/bengaliai_speech/ASR/zipformer/onnx_check.py
+++ b/egs/bengaliai_speech/ASR/zipformer/onnx_check.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/onnx_decode.py b/egs/bengaliai_speech/ASR/zipformer/onnx_decode.py
index a6c5fbb98..0573b88c5 120000
--- a/egs/bengaliai_speech/ASR/zipformer/onnx_decode.py
+++ b/egs/bengaliai_speech/ASR/zipformer/onnx_decode.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/onnx_decode.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/onnx_decode.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained-streaming.py
index 6e63d941f..cfea104c2 120000
--- a/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained-streaming.py
+++ b/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained-streaming.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained.py b/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained.py
index eb2e2d7f2..8f32f4ee7 120000
--- a/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained.py
+++ b/egs/bengaliai_speech/ASR/zipformer/onnx_pretrained.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/optim.py b/egs/bengaliai_speech/ASR/zipformer/optim.py
index 334108434..5eaa3cffd 120000
--- a/egs/bengaliai_speech/ASR/zipformer/optim.py
+++ b/egs/bengaliai_speech/ASR/zipformer/optim.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/optim.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/pretrained.py b/egs/bengaliai_speech/ASR/zipformer/pretrained.py
index 839f64ff0..0bd71dde4 120000
--- a/egs/bengaliai_speech/ASR/zipformer/pretrained.py
+++ b/egs/bengaliai_speech/ASR/zipformer/pretrained.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/pretrained_ctc.py b/egs/bengaliai_speech/ASR/zipformer/pretrained_ctc.py
index 0b53ba71a..c2f6f6fc3 120000
--- a/egs/bengaliai_speech/ASR/zipformer/pretrained_ctc.py
+++ b/egs/bengaliai_speech/ASR/zipformer/pretrained_ctc.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/pretrained_ctc.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/pretrained_ctc.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/profile.py b/egs/bengaliai_speech/ASR/zipformer/profile.py
index 05fcd0d73..c93adbd14 120000
--- a/egs/bengaliai_speech/ASR/zipformer/profile.py
+++ b/egs/bengaliai_speech/ASR/zipformer/profile.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/profile.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/profile.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/scaling.py b/egs/bengaliai_speech/ASR/zipformer/scaling.py
index a094239d4..6f398f431 120000
--- a/egs/bengaliai_speech/ASR/zipformer/scaling.py
+++ b/egs/bengaliai_speech/ASR/zipformer/scaling.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/scaling_converter.py b/egs/bengaliai_speech/ASR/zipformer/scaling_converter.py
index c3357f17a..b0ecee05e 120000
--- a/egs/bengaliai_speech/ASR/zipformer/scaling_converter.py
+++ b/egs/bengaliai_speech/ASR/zipformer/scaling_converter.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/subsampling.py b/egs/bengaliai_speech/ASR/zipformer/subsampling.py
index 1588d5ce0..01ae9002c 120000
--- a/egs/bengaliai_speech/ASR/zipformer/subsampling.py
+++ b/egs/bengaliai_speech/ASR/zipformer/subsampling.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/bengaliai_speech/ASR/zipformer/train.py b/egs/bengaliai_speech/ASR/zipformer/train.py
index 4ecfb5592..e736bb707 100755
--- a/egs/bengaliai_speech/ASR/zipformer/train.py
+++ b/egs/bengaliai_speech/ASR/zipformer/train.py
@@ -26,22 +26,20 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 # For non-streaming model training:
 ./zipformer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 120 \
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp \
-  --full-libri 1 \
   --max-duration 1000
 
 # For streaming model training:
 ./zipformer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 120 \
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir zipformer/exp \
   --causal 1 \
-  --full-libri 1 \
   --max-duration 1000
 
 It supports training with:
diff --git a/egs/bengaliai_speech/ASR/zipformer/zipformer.py b/egs/bengaliai_speech/ASR/zipformer/zipformer.py
index 9cfa01a75..23011dda7 120000
--- a/egs/bengaliai_speech/ASR/zipformer/zipformer.py
+++ b/egs/bengaliai_speech/ASR/zipformer/zipformer.py
@@ -1 +1 @@
-/k2-dev/yangyifan/icefall-bengaliai/egs/librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
deleted file mode 100755
index be239e9c3..000000000
--- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py
+++ /dev/null
@@ -1,445 +0,0 @@
-#!/usr/bin/env python3
-# Copyright      2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
-#                                                         Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This script loads a checkpoint and uses it to decode waves.
-You can generate the checkpoint with the following command:
-
-- For non-streaming model:
-
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
-  --use-ctc 1 \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --epoch 30 \
-  --avg 9
-
-- For streaming model:
-
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
-  --use-ctc 1 \
-  --causal 1 \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --epoch 30 \
-  --avg 9
-
-Usage of this script:
-
-(1) ctc-decoding
-./zipformer/pretrained_ctc.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --method ctc-decoding \
-  --sample-rate 16000 \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(2) 1best
-./zipformer/pretrained_ctc.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --HLG data/lang_bpe_500/HLG.pt \
-  --words-file data/lang_bpe_500/words.txt \
-  --method 1best \
-  --sample-rate 16000 \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(3) nbest-rescoring
-./zipformer/pretrained_ctc.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --HLG data/lang_bpe_500/HLG.pt \
-  --words-file data/lang_bpe_500/words.txt \
-  --G data/lm/G_4_gram.pt \
-  --method nbest-rescoring \
-  --sample-rate 16000 \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-
-(4) whole-lattice-rescoring
-./zipformer/pretrained_ctc.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --HLG data/lang_bpe_500/HLG.pt \
-  --words-file data/lang_bpe_500/words.txt \
-  --G data/lm/G_4_gram.pt \
-  --method whole-lattice-rescoring \
-  --sample-rate 16000 \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-"""
-
-import argparse
-import logging
-import math
-from typing import List
-
-import k2
-import kaldifeat
-import torch
-import torchaudio
-from ctc_decode import get_decoding_params
-from export import num_tokens
-from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_model, get_params
-
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
-)
-from icefall.utils import get_texts
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
-    )
-
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--words-file",
-        type=str,
-        help="""Path to words.txt.
-        Used only when method is not ctc-decoding.
-        """,
-    )
-
-    parser.add_argument(
-        "--HLG",
-        type=str,
-        help="""Path to HLG.pt.
-        Used only when method is not ctc-decoding.
- """, - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a token table, - i.e., lang_dir/tokens.txt, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an LM, the path with - the highest score is the decoding result. - We call it HLG decoding + nbest n-gram LM rescoring. - (3) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + whole-lattice n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or nbest-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and nbest-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help=""" - Used only when method is nbest-rescoring. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) - params.blank_id = token_table[""] - assert params.blank_id == 0 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - batch_size = ctc_output.shape[0] - supervision_segments = torch.tensor( - [ - [i, 0, feature_lengths[i].item() // params.subsampling_factor] - for i in range(batch_size) - ], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - max_token_id = params.vocab_size - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = [[token_table[i] for i in ids] for ids in token_ids] - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - G = G.to(device) - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = 
k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - if params.method == "nbest-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=[params.ngram_lm_scale], - nbest_scale=params.nbest_scale, - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py deleted file mode 100644 index 76622fa12..000000000 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. 
-""" - -import copy -from typing import List, Tuple - -import torch -import torch.nn as nn -from scaling import ( - Balancer, - Dropout3, - ScaleGrad, - SwooshL, - SwooshLOnnx, - SwooshR, - SwooshROnnx, - Whiten, -) -from zipformer import CompactRelPositionalEncoding - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, - is_onnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - is_onnx: - True if we are going to export the model for ONNX. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): - d[name] = nn.Identity() - elif is_onnx and isinstance(m, SwooshR): - d[name] = SwooshROnnx() - elif is_onnx and isinstance(m, SwooshL): - d[name] = SwooshLOnnx() - elif is_onnx and isinstance(m, CompactRelPositionalEncoding): - # We want to recreate the positional encoding vector when - # the input changes, so we have to use torch.jit.script() - # to replace torch.jit.trace() - d[name] = torch.jit.script(m) - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 44ff392a3..000000000 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,876 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import LibriSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
-        Used only when --decoding-method is fast_beam_search""",
-    )
-
-    parser.add_argument(
-        "--max-contexts",
-        type=int,
-        default=4,
-        help="""Used only when --decoding-method is
-        fast_beam_search""",
-    )
-
-    parser.add_argument(
-        "--max-states",
-        type=int,
-        default=32,
-        help="""Used only when --decoding-method is
-        fast_beam_search""",
-    )
-
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--num-decode-streams",
-        type=int,
-        default=2000,
-        help="The number of streams that can be decoded parallel.",
-    )
-
-    add_model_arguments(parser)
-
-    return parser
-
-
-def get_init_states(
-    model: nn.Module,
-    batch_size: int = 1,
-    device: torch.device = torch.device("cpu"),
-) -> List[torch.Tensor]:
-    """
-    Returns a list of cached tensors of all encoder layers. For layer-i,
-    states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
-    cached_val2, cached_conv1, cached_conv2).
-    states[-2] is the cached left padding for ConvNeXt module,
-    of shape (batch_size, num_channels, left_pad, num_freqs)
-    states[-1] is processed_lens of shape (batch,), which records the number
-    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
-    """
-    states = model.encoder.get_init_states(batch_size, device)
-
-    embed_states = model.encoder_embed.get_init_states(batch_size, device)
-    states.append(embed_states)
-
-    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
-    states.append(processed_lens)
-
-    return states
-
-
-def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
-    """Stack list of zipformer states that correspond to separate utterances
-    into a single emformer state, so that it can be used as an input for
-    zipformer when those utterances are formed into a batch.
-
-    Args:
-      state_list:
-        Each element in state_list corresponding to the internal state
-        of the zipformer model for a single utterance. For element-n,
-        state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
-        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
-        cached_val2, cached_conv1, cached_conv2).
-        state_list[n][-2] is the cached left padding for ConvNeXt module,
-        of shape (batch_size, num_channels, left_pad, num_freqs)
-        state_list[n][-1] is processed_lens of shape (batch,), which records the number
-        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
-
-    Note:
-      It is the inverse of :func:`unstack_states`.
- """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. 
- """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. 
- """ - cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. 
-    # The encoder_embed subsample features (T - 7) // 2
-    # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
-    tail_length = chunk_size * 2 + 7 + 2 * 3
-    if features.size(1) < tail_length:
-        pad_length = tail_length - features.size(1)
-        feature_lens += pad_length
-        features = torch.nn.functional.pad(
-            features,
-            (0, 0, 0, pad_length),
-            mode="constant",
-            value=LOG_EPS,
-        )
-
-    states = stack_states(states)
-
-    encoder_out, encoder_out_lens, new_states = streaming_forward(
-        features=features,
-        feature_lens=feature_lens,
-        model=model,
-        states=states,
-        chunk_size=chunk_size,
-        left_context_len=left_context_len,
-    )
-
-    encoder_out = model.joiner.encoder_proj(encoder_out)
-
-    if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
-    elif params.decoding_method == "fast_beam_search":
-        processed_lens = torch.tensor(processed_lens, device=device)
-        processed_lens = processed_lens + encoder_out_lens
-        fast_beam_search_one_best(
-            model=model,
-            encoder_out=encoder_out,
-            processed_lens=processed_lens,
-            streams=decode_streams,
-            beam=params.beam,
-            max_states=params.max_states,
-            max_contexts=params.max_contexts,
-        )
-    elif params.decoding_method == "modified_beam_search":
-        modified_beam_search(
-            model=model,
-            streams=decode_streams,
-            encoder_out=encoder_out,
-            num_active_paths=params.num_active_paths,
-        )
-    else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
-
-    states = unstack_states(new_states)
-
-    finished_streams = []
-    for i in range(len(decode_streams)):
-        decode_streams[i].states = states[i]
-        decode_streams[i].done_frames += encoder_out_lens[i]
-        if decode_streams[i].done:
-            finished_streams.append(i)
-
-    return finished_streams
-
-
-def decode_dataset(
-    cuts: CutSet,
-    params: AttributeDict,
-    model: nn.Module,
-    sp: spm.SentencePieceProcessor,
-    decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[str], List[str]]]]:
-    """Decode dataset.
-
-    Args:
-      cuts:
-        Lhotse Cutset containing the dataset to decode.
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The neural model.
-      sp:
-        The BPE model.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search.
-    Returns:
-      Return a dict, whose key may be "greedy_search" if greedy search
-      is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
-    """
-    device = model.device
-
-    opts = FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = 16000
-    opts.mel_opts.num_bins = 80
-
-    log_interval = 100
-
-    decode_results = []
-    # Contain decode streams currently running.
-    decode_streams = []
-    for num, cut in enumerate(cuts):
-        # each utterance has a DecodeStream.
-        initial_states = get_init_states(
-            model=model, batch_size=1, device=device
-        )
-        decode_stream = DecodeStream(
-            params=params,
-            cut_id=cut.id,
-            initial_states=initial_states,
-            decoding_graph=decoding_graph,
-            device=device,
-        )
-
-        audio: np.ndarray = cut.load_audio()
-        # audio.shape: (1, num_samples)
-        assert len(audio.shape) == 2
-        assert audio.shape[0] == 1, "Should be single channel"
-        assert audio.dtype == np.float32, audio.dtype
-
-        # The trained model is using normalized samples
-        assert audio.max() <= 1, "Should be normalized to [-1, 1])"
-
-        samples = torch.from_numpy(audio).squeeze(0)
-
-        fbank = Fbank(opts)
-        feature = fbank(samples.to(device))
-        decode_stream.set_features(feature, tail_pad_len=30)
-        decode_stream.ground_truth = cut.supervisions[0].text
-
-        decode_streams.append(decode_stream)
-
-        while len(decode_streams) >= params.num_decode_streams:
-            finished_streams = decode_one_chunk(
-                params=params, model=model, decode_streams=decode_streams
-            )
-            for i in sorted(finished_streams, reverse=True):
-                decode_results.append(
-                    (
-                        decode_streams[i].id,
-                        decode_streams[i].ground_truth.split(),
-                        sp.decode(decode_streams[i].decoding_result()).split(),
-                    )
-                )
-                del decode_streams[i]
-
-        if num % log_interval == 0:
-            logging.info(f"Cuts processed until now is {num}.")
-
-    # decode final chunks of last sequences
-    while len(decode_streams):
-        finished_streams = decode_one_chunk(
-            params=params, model=model, decode_streams=decode_streams
-        )
-        for i in sorted(finished_streams, reverse=True):
-            decode_results.append(
-                (
-                    decode_streams[i].id,
-                    decode_streams[i].ground_truth.split(),
-                    sp.decode(decode_streams[i].decoding_result()).split(),
-                )
-            )
-            del decode_streams[i]
-
-    if params.decoding_method == "greedy_search":
-        key = "greedy_search"
-    elif params.decoding_method == "fast_beam_search":
-        key = (
-            f"beam_{params.beam}_"
-            f"max_contexts_{params.max_contexts}_"
-            f"max_states_{params.max_states}"
-        )
-    elif params.decoding_method == "modified_beam_search":
-        key = f"num_active_paths_{params.num_active_paths}"
-    else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
-    return {key: decode_results}
-
-
-def save_results(
-    params: AttributeDict,
-    test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
-):
-    test_set_wers = dict()
-    for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
-
-        # The following prints out WERs, per-word error statistics and aligned
-        # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
-            )
-            test_set_wers[key] = wer
-
-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
-
-    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
-    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
-        note = ""
-    logging.info(s)
-
-
-@torch.no_grad()
-def main():
-    parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
-    args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
-
-    params = get_params()
-    params.update(vars(args))
-
-    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
-
-    if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
-    else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-
-    assert params.causal, params.causal
-    assert (
-        "," not in params.chunk_size
-    ), "chunk_size should be one value in decoding."
-    assert (
-        "," not in params.left_context_frames
-    ), "left_context_frames should be one value in decoding."
-    params.suffix += f"-chunk-{params.chunk_size}"
-    params.suffix += f"-left-context-{params.left_context_frames}"
-
-    # for fast_beam_search
-    if params.decoding_method == "fast_beam_search":
-        params.suffix += f"-beam-{params.beam}"
-        params.suffix += f"-max-contexts-{params.max_contexts}"
-        params.suffix += f"-max-states-{params.max_states}"
-
-    if params.use_averaged_model:
-        params.suffix += "-use-averaged-model"
-
-    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
-    logging.info("Decoding started")
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"Device: {device}")
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> and <unk> are defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
-
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_model(params)
-
-    if not params.use_averaged_model:
-        if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        else:
-            start = params.epoch - params.avg + 1
-            filenames = []
-            for i in range(start, params.epoch + 1):
-                if start >= 0:
-                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-    else:
-        if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg + 1:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            filename_start = filenames[-1]
-            filename_end = filenames[0]
-            logging.info(
-                "Calculating the averaged model over iteration checkpoints"
-                f" from {filename_start} (excluded) to {filename_end}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-        else:
-            assert params.avg > 0, params.avg
-            start = params.epoch - params.avg
-            assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-
-    model.to(device)
-    model.eval()
-    model.device = device
-
-    decoding_graph = None
-    if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
-
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
-
-    librispeech = LibriSpeechAsrDataModule(args)
-
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
-
-    test_sets = ["test-clean", "test-other"]
-    test_cuts = [test_clean_cuts, test_other_cuts]
-
-    for test_set, test_cut in zip(test_sets, test_cuts):
-        results_dict = decode_dataset(
-            cuts=test_cut,
-            params=params,
-            model=model,
-            sp=sp,
-            decoding_graph=decoding_graph,
-        )
-
-        save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
-        )
-
-    logging.info("Done!")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py
new file mode 120000
index 000000000..b1ed54557
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/streaming_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file