mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-12 18:44:20 +00:00
minor fixes for aidatatang_200zh zipformer recipe
This commit is contained in:
parent
94eedab88e
commit
68c5619608
@ -75,7 +75,7 @@ from typing import List, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import WenetSpeechAsrDataModule
|
from asr_datamodule import Aidatatang_200zhAsrDataModule
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from onnx_pretrained import OnnxModel, greedy_search
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
@ -256,7 +256,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
WenetSpeechAsrDataModule.add_arguments(parser)
|
Aidatatang_200zhAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@ -285,7 +285,7 @@ def main():
|
|||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
|
|
||||||
wenetspeech = WenetSpeechAsrDataModule(args)
|
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
|
||||||
|
|
||||||
def remove_short_utt(c: Cut):
|
def remove_short_utt(c: Cut):
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
@ -295,20 +295,16 @@ def main():
|
|||||||
)
|
)
|
||||||
return T > 0
|
return T > 0
|
||||||
|
|
||||||
dev_cuts = wenetspeech.valid_cuts()
|
dev_cuts = aidatatang_200zh.valid_cuts()
|
||||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||||
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
|
dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
|
||||||
|
|
||||||
test_net_cuts = wenetspeech.test_net_cuts()
|
test_cuts = aidatatang_200zh.test_cuts()
|
||||||
test_net_cuts = test_net_cuts.filter(remove_short_utt)
|
test_cuts = test_cuts.filter(remove_short_utt)
|
||||||
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
|
test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_meeting_cuts = wenetspeech.test_meeting_cuts()
|
test_sets = ["dev", "test"]
|
||||||
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
|
test_dl = [dev_dl, test_dl]
|
||||||
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
|
||||||
|
|
||||||
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
|
|
||||||
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user