minor fixes for aidatatang_200zh zipformer recipe

This commit is contained in:
jinzr 2023-08-13 00:51:27 +08:00
parent 94eedab88e
commit 68c5619608

View File

@ -75,7 +75,7 @@ from typing import List, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import Aidatatang_200zhAsrDataModule
from lhotse.cut import Cut from lhotse.cut import Cut
from onnx_pretrained import OnnxModel, greedy_search from onnx_pretrained import OnnxModel, greedy_search
@ -256,7 +256,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser) Aidatatang_200zhAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
assert ( assert (
@ -285,7 +285,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args) aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
def remove_short_utt(c: Cut): def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2 T = ((c.num_frames - 7) // 2 + 1) // 2
@ -295,20 +295,16 @@ def main():
) )
return T > 0 return T > 0
dev_cuts = wenetspeech.valid_cuts() dev_cuts = aidatatang_200zh.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt) dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts) dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
test_net_cuts = wenetspeech.test_net_cuts() test_cuts = aidatatang_200zh.test_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt) test_cuts = test_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts() test_sets = ["dev", "test"]
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) test_dl = [dev_dl, test_dl]
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time() start_time = time.time()