minor updates

This commit is contained in:
JinZr 2023-08-09 16:27:36 +08:00
parent 97add79204
commit ba67942656
3 changed files with 24 additions and 24 deletions

View File

@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule:
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=True, default=False,
help="When enabled, the batches will come from buckets of " help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).", "similar duration (saves padding frames).",
) )
@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule:
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
drop_last=True, drop_last=True,
buffer_size=50000,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SingleCutSampler.")

View File

@ -88,7 +88,7 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import Aidatatang_200zhAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
@ -596,7 +596,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser) Aidatatang_200zhAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -770,7 +770,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args) aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
def remove_short_utt(c: Cut): def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2 T = ((c.num_frames - 7) // 2 + 1) // 2
@ -780,20 +780,16 @@ def main():
) )
return T > 0 return T > 0
dev_cuts = wenetspeech.valid_cuts() dev_cuts = aidatatang_200zh.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt) dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts) dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
test_net_cuts = wenetspeech.test_net_cuts() test_cuts = aidatatang_200zh.test_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt) test_cuts = test_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts() test_sets = ["valid_cuts", "test_cuts"]
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) test_dls = [dev_dl, test_dl]
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dls): for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset( results_dict = decode_dataset(

View File

@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import numpy as np import numpy as np
import torch import torch
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import Aidatatang_200zhAsrDataModule
from decode_stream import DecodeStream from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet from lhotse import CutSet
@ -386,7 +386,11 @@ def streaming_forward(
Returns encoder outputs, output lengths, and updated states. Returns encoder outputs, output lengths, and updated states.
""" """
cached_embed_left_pad = states[-2] cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( (
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features, x=features,
x_lens=feature_lens, x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad, cached_left_pad=cached_embed_left_pad,
@ -713,7 +717,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser) Aidatatang_200zhAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -851,14 +855,13 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
wenetspeech = WenetSpeechAsrDataModule(args) aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
dev_cuts = wenetspeech.valid_cuts() dev_cuts = aidatatang_200zh.valid_cuts()
test_net_cuts = wenetspeech.test_net_cuts() test_cuts = aidatatang_200zh.test_cuts()
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_sets = ["valid_cuts", "test_cuts"]
test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts] test_cuts = [dev_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts): for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset( results_dict = decode_dataset(