mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
added scripts for streaming related stuff
This commit is contained in:
parent
60691efddf
commit
a421792863
1
egs/mdcc/ASR/zipformer/decode_stream.py
Symbolic link
1
egs/mdcc/ASR/zipformer/decode_stream.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/decode_stream.py
|
@ -31,7 +31,7 @@ from typing import List, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from asr_datamodule import MdccAsrDataModule
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from onnx_pretrained import OnnxModel, greedy_search
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
AishellAsrDataModule.add_arguments(parser)
|
MdccAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@ -241,7 +241,7 @@ def main():
|
|||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
mdcc = MdccAsrDataModule(args)
|
||||||
|
|
||||||
def remove_short_utt(c: Cut):
|
def remove_short_utt(c: Cut):
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
@ -251,16 +251,16 @@ def main():
|
|||||||
)
|
)
|
||||||
return T > 0
|
return T > 0
|
||||||
|
|
||||||
dev_cuts = aishell.valid_cuts()
|
valid_cuts = mdcc.valid_cuts()
|
||||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
valid_cuts = valid_cuts.filter(remove_short_utt)
|
||||||
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
valid_dl = mdcc.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
test_cuts = aishell.test_net_cuts()
|
test_cuts = mdcc.test_net_cuts()
|
||||||
test_cuts = test_cuts.filter(remove_short_utt)
|
test_cuts = test_cuts.filter(remove_short_utt)
|
||||||
test_dl = aishell.test_dataloaders(test_cuts)
|
test_dl = mdcc.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_sets = ["dev", "test"]
|
test_sets = ["valid", "test"]
|
||||||
test_dl = [dev_dl, test_dl]
|
test_dl = [valid_dl, test_dl]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
1
egs/mdcc/ASR/zipformer/streaming_beam_search.py
Symbolic link
1
egs/mdcc/ASR/zipformer/streaming_beam_search.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/streaming_beam_search.py
|
@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from asr_datamodule import MdccAsrDataModule
|
||||||
from decode_stream import DecodeStream
|
from decode_stream import DecodeStream
|
||||||
from kaldifeat import Fbank, FbankOptions
|
from kaldifeat import Fbank, FbankOptions
|
||||||
from lhotse import CutSet
|
from lhotse import CutSet
|
||||||
@ -177,7 +177,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=1,
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -386,7 +386,11 @@ def streaming_forward(
|
|||||||
Returns encoder outputs, output lengths, and updated states.
|
Returns encoder outputs, output lengths, and updated states.
|
||||||
"""
|
"""
|
||||||
cached_embed_left_pad = states[-2]
|
cached_embed_left_pad = states[-2]
|
||||||
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
|
(
|
||||||
|
x,
|
||||||
|
x_lens,
|
||||||
|
new_cached_embed_left_pad,
|
||||||
|
) = model.encoder_embed.streaming_forward(
|
||||||
x=features,
|
x=features,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
cached_left_pad=cached_embed_left_pad,
|
cached_left_pad=cached_embed_left_pad,
|
||||||
@ -714,7 +718,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
AishellAsrDataModule.add_arguments(parser)
|
MdccAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -852,13 +856,13 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
mdcc = MdccAsrDataModule(args)
|
||||||
|
|
||||||
dev_cuts = aishell.valid_cuts()
|
valid_cuts = mdcc.valid_cuts()
|
||||||
test_cuts = aishell.test_cuts()
|
test_cuts = mdcc.test_cuts()
|
||||||
|
|
||||||
test_sets = ["dev", "test"]
|
test_sets = ["valid", "test"]
|
||||||
test_cuts = [dev_cuts, test_cuts]
|
test_cuts = [valid_cuts, test_cuts]
|
||||||
|
|
||||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user