diff --git a/egs/mdcc/ASR/zipformer/decode_stream.py b/egs/mdcc/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/onnx_decode.py b/egs/mdcc/ASR/zipformer/onnx_decode.py index 17c6eceb4..1ed4a9fa1 100755 --- a/egs/mdcc/ASR/zipformer/onnx_decode.py +++ b/egs/mdcc/ASR/zipformer/onnx_decode.py @@ -31,7 +31,7 @@ from typing import List, Tuple import k2 import torch import torch.nn as nn -from asr_datamodule import AishellAsrDataModule +from asr_datamodule import MdccAsrDataModule from lhotse.cut import Cut from onnx_pretrained import OnnxModel, greedy_search @@ -212,7 +212,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - AishellAsrDataModule.add_arguments(parser) + MdccAsrDataModule.add_arguments(parser) args = parser.parse_args() assert ( @@ -241,7 +241,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - aishell = AishellAsrDataModule(args) + mdcc = MdccAsrDataModule(args) def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 @@ -251,16 +251,16 @@ def main(): ) return T > 0 - dev_cuts = aishell.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = aishell.valid_dataloaders(dev_cuts) + valid_cuts = mdcc.valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = mdcc.valid_dataloaders(valid_cuts) - test_cuts = aishell.test_net_cuts() + test_cuts = mdcc.test_net_cuts() test_cuts = test_cuts.filter(remove_short_utt) - test_dl = aishell.test_dataloaders(test_cuts) + test_dl = mdcc.test_dataloaders(test_cuts) - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] + test_sets = ["valid", "test"] + test_dl = [valid_dl, test_dl] for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() diff --git a/egs/mdcc/ASR/zipformer/streaming_beam_search.py b/egs/mdcc/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/mdcc/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/mdcc/ASR/zipformer/streaming_decode.py b/egs/mdcc/ASR/zipformer/streaming_decode.py index 6a7ef2750..fc457e823 100755 --- a/egs/mdcc/ASR/zipformer/streaming_decode.py +++ b/egs/mdcc/ASR/zipformer/streaming_decode.py @@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple import k2 import numpy as np import torch -from asr_datamodule import AishellAsrDataModule +from asr_datamodule import MdccAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -177,7 +177,7 @@ def get_parser(): parser.add_argument( "--context-size", type=int, - default=2, + default=1, help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) @@ -386,7 +386,11 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -714,7 +718,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - AishellAsrDataModule.add_arguments(parser) + MdccAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -852,13 +856,13 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - aishell = AishellAsrDataModule(args) + mdcc = MdccAsrDataModule(args) - dev_cuts = aishell.valid_cuts() - test_cuts = aishell.test_cuts() + valid_cuts = mdcc.valid_cuts() + test_cuts = mdcc.test_cuts() - test_sets = ["dev", "test"] - test_cuts = [dev_cuts, test_cuts] + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset(