From 68c5619608961ded8dbcc4898b68580a336c2301 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Sun, 13 Aug 2023 00:51:27 +0800 Subject: [PATCH] minor fixes for aidatatang_200zh zipformer recipe --- .../ASR/zipformer/onnx_decode.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/zipformer/onnx_decode.py b/egs/aidatatang_200zh/ASR/zipformer/onnx_decode.py index 4bd88e546..6cd91686e 100755 --- a/egs/aidatatang_200zh/ASR/zipformer/onnx_decode.py +++ b/egs/aidatatang_200zh/ASR/zipformer/onnx_decode.py @@ -75,7 +75,7 @@ from typing import List, Tuple import k2 import torch import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import Aidatatang_200zhAsrDataModule from lhotse.cut import Cut from onnx_pretrained import OnnxModel, greedy_search @@ -256,7 +256,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + Aidatatang_200zhAsrDataModule.add_arguments(parser) args = parser.parse_args() assert ( @@ -285,7 +285,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) + aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 @@ -295,20 +295,16 @@ def main(): ) return T > 0 - dev_cuts = wenetspeech.valid_cuts() + dev_cuts = aidatatang_200zh.valid_cuts() dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) + dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts) - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + test_cuts = aidatatang_200zh.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aidatatang_200zh.test_dataloaders(test_cuts) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dl = [dev_dl, test_net_dl, test_meeting_dl] + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time()