mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00

commit ba67942656 ("minor updates")
parent 97add79204
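
The commit makes three kinds of changes: it flips the --bucketing-sampler default in the Aidatatang_200zh data module, passes an explicit buffer_size=50000 to the dynamic bucketing sampler, and completes the rename from WenetSpeech to Aidatatang_200zh in the offline and streaming decode scripts (the recipe was evidently copied from the WenetSpeech one). The per-file diff headers are not preserved in this view; judging by the hunk contexts, the first two hunks touch asr_datamodule.py, the next four an offline decode script, and the last four a streaming decode script.
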
@@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule:
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
-            default=True,
+            default=False,
             help="When enabled, the batches will come from buckets of "
             "similar duration (saves padding frames).",
         )
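
The only functional change here is the default: bucketing now has to be enabled explicitly. type=str2bool lets the flag accept textual True/False values on the command line; in these recipes it comes from icefall.utils, and a minimal equivalent (a sketch for self-containedness, not the exact icefall code) is:

    import argparse

    def str2bool(v) -> bool:
        # Accept the usual textual spellings of a boolean flag value.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")
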
@@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=True,
+                buffer_size=50000,
             )
         else:
             logging.info("Using SingleCutSampler.")
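
buffer_size bounds how many cuts DynamicBucketingSampler keeps in memory while grouping cuts of similar duration into buckets; a larger buffer gives more homogeneous buckets at the cost of RAM. A minimal sketch of the branch these context lines belong to, assuming lhotse's sampler API and the surrounding train_dataloaders variables (args, cuts_train), which this diff does not show:

    import logging

    from lhotse import CutSet
    from lhotse.dataset import DynamicBucketingSampler, SingleCutSampler

    def make_train_sampler(args, cuts_train: CutSet):
        # Sketch only: mirrors the if/else these context lines come from.
        if args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            return DynamicBucketingSampler(
                cuts_train,
                max_duration=args.max_duration,  # seconds of audio per batch
                shuffle=args.shuffle,
                num_buckets=args.num_buckets,
                drop_last=True,
                buffer_size=50000,  # cuts buffered for duration bucketing
            )
        logging.info("Using SingleCutSampler.")
        return SingleCutSampler(
            cuts_train,
            max_duration=args.max_duration,
            shuffle=args.shuffle,
        )

The hunks that follow switch files: judging by the beam_search imports, they belong to the recipe's offline decode script.
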
@@ -88,7 +88,7 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
-from asr_datamodule import WenetSpeechAsrDataModule
+from asr_datamodule import Aidatatang_200zhAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -596,7 +596,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    WenetSpeechAsrDataModule.add_arguments(parser)
+    Aidatatang_200zhAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
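
The rename is mechanical because the data module contributes its options through a classmethod hook, so only the class name changes at the call site. A sketch of that hook's shape (the real implementation is the asr_datamodule.py touched in the first two hunks):

    import argparse

    from icefall.utils import str2bool

    class Aidatatang_200zhAsrDataModule:
        # Sketch of the argparse hook; the real class also builds the
        # train/valid/test dataloaders from lhotse manifests.
        @classmethod
        def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
            group = parser.add_argument_group(title="ASR data related options")
            group.add_argument(
                "--bucketing-sampler",
                type=str2bool,
                default=False,  # the new default from the first hunk
                help="When enabled, the batches will come from buckets of "
                "similar duration (saves padding frames).",
            )
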
@@ -770,7 +770,7 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    wenetspeech = WenetSpeechAsrDataModule(args)
+    aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
 
     def remove_short_utt(c: Cut):
         T = ((c.num_frames - 7) // 2 + 1) // 2
@@ -780,20 +780,16 @@ def main():
         )
         return T > 0
 
-    dev_cuts = wenetspeech.valid_cuts()
+    dev_cuts = aidatatang_200zh.valid_cuts()
     dev_cuts = dev_cuts.filter(remove_short_utt)
-    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+    dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts)
 
-    test_net_cuts = wenetspeech.test_net_cuts()
-    test_net_cuts = test_net_cuts.filter(remove_short_utt)
-    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
-
-    test_meeting_cuts = wenetspeech.test_meeting_cuts()
-    test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
-    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
-
-    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dls = [dev_dl, test_net_dl, test_meeting_dl]
+    test_cuts = aidatatang_200zh.test_cuts()
+    test_cuts = test_cuts.filter(remove_short_utt)
+    test_dl = aidatatang_200zh.test_dataloaders(test_cuts)
+
+    test_sets = ["valid_cuts", "test_cuts"]
+    test_dls = [dev_dl, test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
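
remove_short_utt drops cuts that would produce an empty encoder output: the front-end subsampling shrinks num_frames to ((num_frames - 7) // 2 + 1) // 2, so anything shorter than 9 feature frames gives T == 0 and is filtered out before decoding. A quick check of the arithmetic:

    def subsampled_len(num_frames: int) -> int:
        # Same arithmetic as remove_short_utt in the hunks above.
        return ((num_frames - 7) // 2 + 1) // 2

    assert subsampled_len(7) == 0   # dropped: T > 0 is False
    assert subsampled_len(8) == 0   # dropped
    assert subsampled_len(9) == 1   # kept: one encoder output frame
    assert subsampled_len(13) == 2  # kept

The remaining hunks apply the same rename to the streaming decode script (the DecodeStream and kaldifeat imports give it away).
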
@@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import numpy as np
 import torch
-from asr_datamodule import WenetSpeechAsrDataModule
+from asr_datamodule import Aidatatang_200zhAsrDataModule
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
@@ -386,7 +386,11 @@ def streaming_forward(
       Returns encoder outputs, output lengths, and updated states.
     """
     cached_embed_left_pad = states[-2]
-    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
+    (
+        x,
+        x_lens,
+        new_cached_embed_left_pad,
+    ) = model.encoder_embed.streaming_forward(
         x=features,
         x_lens=feature_lens,
         cached_left_pad=cached_embed_left_pad,
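
This hunk is formatting only: the long single-line tuple assignment is rewrapped in the multi-line style black produces, and the call itself is untouched. Both spellings bind the same three names:

    def fn():
        return 1, 2, 3

    # old, single-line form
    (x, x_lens, new_pad) = fn()

    # new, black-style multi-line form; identical semantics
    (
        x2,
        x_lens2,
        new_pad2,
    ) = fn()

    assert (x, x_lens, new_pad) == (x2, x_lens2, new_pad2)
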
@@ -713,7 +717,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    WenetSpeechAsrDataModule.add_arguments(parser)
+    Aidatatang_200zhAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -851,14 +855,13 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    wenetspeech = WenetSpeechAsrDataModule(args)
+    aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
 
-    dev_cuts = wenetspeech.valid_cuts()
-    test_net_cuts = wenetspeech.test_net_cuts()
-    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    dev_cuts = aidatatang_200zh.valid_cuts()
+    test_cuts = aidatatang_200zh.test_cuts()
 
-    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts]
+    test_sets = ["valid_cuts", "test_cuts"]
+    test_cuts = [dev_cuts, test_cuts]
 
     for test_set, test_cut in zip(test_sets, test_cuts):
         results_dict = decode_dataset(
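
One nit worth flagging in that last hunk: test_cuts = [dev_cuts, test_cuts] rebinds test_cuts from a CutSet to a list containing it. This works, since the list is built before the name is rebound, but a hypothetical equivalent without the shadowing (illustrative names, not part of the commit) would be:

    def gather_test_sets(aidatatang_200zh):
        # Hypothetical, behavior-equivalent tail of main() above; gives
        # the list its own name instead of rebinding test_cuts.
        dev_cuts = aidatatang_200zh.valid_cuts()
        test_cuts = aidatatang_200zh.test_cuts()

        test_sets = ["valid_cuts", "test_cuts"]
        cut_sets = [dev_cuts, test_cuts]
        return list(zip(test_sets, cut_sets))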