mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
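Add streaming decoding scripts for the MDCC zipformer recipe: symlink decode_stream.py and streaming_beam_search.py from the LibriSpeech recipe, and switch the decoding scripts from AishellAsrDataModule to MdccAsrDataModule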
This commit is contained in:
parent
60691efddf
commit
a421792863
1
egs/mdcc/ASR/zipformer/decode_stream.py
Symbolic link
1
egs/mdcc/ASR/zipformer/decode_stream.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/decode_stream.py
|
@ -31,7 +31,7 @@ from typing import List, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from asr_datamodule import MdccAsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
@ -212,7 +212,7 @@ def save_results(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
MdccAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert (
|
||||
@ -241,7 +241,7 @@ def main():
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
mdcc = MdccAsrDataModule(args)
|
||||
|
||||
def remove_short_utt(c: Cut):
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
@ -251,16 +251,16 @@ def main():
|
||||
)
|
||||
return T > 0
|
||||
|
||||
dev_cuts = aishell.valid_cuts()
|
||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
||||
valid_cuts = mdcc.valid_cuts()
|
||||
valid_cuts = valid_cuts.filter(remove_short_utt)
|
||||
valid_dl = mdcc.valid_dataloaders(valid_cuts)
|
||||
|
||||
test_cuts = aishell.test_net_cuts()
|
||||
test_cuts = mdcc.test_net_cuts()
|
||||
test_cuts = test_cuts.filter(remove_short_utt)
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
test_dl = mdcc.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_dl = [dev_dl, test_dl]
|
||||
test_sets = ["valid", "test"]
|
||||
test_dl = [valid_dl, test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
start_time = time.time()
|
||||
|
1
egs/mdcc/ASR/zipformer/streaming_beam_search.py
Symbolic link
1
egs/mdcc/ASR/zipformer/streaming_beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/streaming_beam_search.py
|
@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from asr_datamodule import MdccAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
@ -177,7 +177,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
@ -386,7 +386,11 @@ def streaming_forward(
|
||||
Returns encoder outputs, output lengths, and updated states.
|
||||
"""
|
||||
cached_embed_left_pad = states[-2]
|
||||
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
|
||||
(
|
||||
x,
|
||||
x_lens,
|
||||
new_cached_embed_left_pad,
|
||||
) = model.encoder_embed.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
@ -714,7 +718,7 @@ def save_results(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
MdccAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
@ -852,13 +856,13 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
mdcc = MdccAsrDataModule(args)
|
||||
|
||||
dev_cuts = aishell.valid_cuts()
|
||||
test_cuts = aishell.test_cuts()
|
||||
valid_cuts = mdcc.valid_cuts()
|
||||
test_cuts = mdcc.test_cuts()
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_cuts = [dev_cuts, test_cuts]
|
||||
test_sets = ["valid", "test"]
|
||||
test_cuts = [valid_cuts, test_cuts]
|
||||
|
||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||
results_dict = decode_dataset(
|
||||
|
Loading…
x
Reference in New Issue
Block a user