added scripts for streaming related stuff

This commit is contained in:
jinzr 2024-03-11 11:08:15 +08:00
parent 60691efddf
commit a421792863
4 changed files with 25 additions and 19 deletions

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -31,7 +31,7 @@ from typing import List, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from asr_datamodule import MdccAsrDataModule
from lhotse.cut import Cut from lhotse.cut import Cut
from onnx_pretrained import OnnxModel, greedy_search from onnx_pretrained import OnnxModel, greedy_search
@ -212,7 +212,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
AishellAsrDataModule.add_arguments(parser) MdccAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
assert ( assert (
@ -241,7 +241,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
aishell = AishellAsrDataModule(args) mdcc = MdccAsrDataModule(args)
def remove_short_utt(c: Cut): def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2 T = ((c.num_frames - 7) // 2 + 1) // 2
@ -251,16 +251,16 @@ def main():
) )
return T > 0 return T > 0
dev_cuts = aishell.valid_cuts() valid_cuts = mdcc.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt) valid_cuts = valid_cuts.filter(remove_short_utt)
dev_dl = aishell.valid_dataloaders(dev_cuts) valid_dl = mdcc.valid_dataloaders(valid_cuts)
test_cuts = aishell.test_net_cuts() test_cuts = mdcc.test_net_cuts()
test_cuts = test_cuts.filter(remove_short_utt) test_cuts = test_cuts.filter(remove_short_utt)
test_dl = aishell.test_dataloaders(test_cuts) test_dl = mdcc.test_dataloaders(test_cuts)
test_sets = ["dev", "test"] test_sets = ["valid", "test"]
test_dl = [dev_dl, test_dl] test_dl = [valid_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time() start_time = time.time()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import numpy as np import numpy as np
import torch import torch
from asr_datamodule import AishellAsrDataModule from asr_datamodule import MdccAsrDataModule
from decode_stream import DecodeStream from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet from lhotse import CutSet
@ -177,7 +177,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=2, default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
) )
@ -386,7 +386,11 @@ def streaming_forward(
Returns encoder outputs, output lengths, and updated states. Returns encoder outputs, output lengths, and updated states.
""" """
cached_embed_left_pad = states[-2] cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( (
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features, x=features,
x_lens=feature_lens, x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad, cached_left_pad=cached_embed_left_pad,
@ -714,7 +718,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
AishellAsrDataModule.add_arguments(parser) MdccAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -852,13 +856,13 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args) mdcc = MdccAsrDataModule(args)
dev_cuts = aishell.valid_cuts() valid_cuts = mdcc.valid_cuts()
test_cuts = aishell.test_cuts() test_cuts = mdcc.test_cuts()
test_sets = ["dev", "test"] test_sets = ["valid", "test"]
test_cuts = [dev_cuts, test_cuts] test_cuts = [valid_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts): for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset( results_dict = decode_dataset(