Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 21:44:18 +00:00
modify streaming decoding files

commit 5c669b7716 (parent 822cc78a9c)
decode_stream.py

@@ -71,9 +71,9 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        # add 2 here since we will drop the first and last frames after
+        # the convolutional subsampling module
+        self.pad_length = 2 * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
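For intuition, a minimal sketch of the new padding arithmetic (the subsampling factor of 4 and the reading of the "+ 3" term are assumptions, not stated in the diff); the right-context term disappears because the LSTM encoder has no look-ahead:

    def pad_length(subsampling_factor: int) -> int:
        # 2 * subsampling_factor compensates for the first and last frames
        # dropped after the convolutional subsampling module; the "+ 3" is
        # assumed to cover frames consumed by the subsampling convolutions.
        return 2 * subsampling_factor + 3

    print(pad_length(4))  # 11 feature frames of padding per chunk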
streaming_decode.py

@@ -17,15 +17,13 @@
 
 """
 Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --left-context 32 \
-        --decode-chunk-size 8 \
-        --right-context 0 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --decoding_method greedy_search \
-        --num-decode-streams 1000
+./lstm_transducer_stateless/streaming_decode.py \
+        --epoch 28 \
+        --avg 15 \
+        --exp-dir ./lstm_transducer_stateless/exp \
+        --decoding_method greedy_search \
+        --decode-chunk-size 1 \
+        --num-decode-streams 1000
 """
 
 import argparse
@@ -44,7 +42,7 @@ from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -155,24 +153,10 @@ def get_parser():
     parser.add_argument(
         "--decode-chunk-size",
         type=int,
-        default=16,
+        default=1,
         help="The chunk size for decoding (in frames after subsampling)",
     )
 
-    parser.add_argument(
-        "--left-context",
-        type=int,
-        default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",
-    )
-
-    parser.add_argument(
-        "--right-context",
-        type=int,
-        default=0,
-        help="right context can be seen during decoding (in frames after subsampling)",
-    )
-
     parser.add_argument(
         "--num-decode-streams",
         type=int,
@@ -180,8 +164,6 @@ def get_parser():
         help="The number of streams that can be decoded in parallel.",
     )
 
-    add_model_arguments(parser)
-
     return parser
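Back-of-envelope latency for the new --decode-chunk-size default, as a sketch (the 10 ms frame shift and subsampling factor of 4 are assumptions, typical for icefall recipes):

    def chunk_ms(decode_chunk_size: int, subsampling_factor: int = 4,
                 frame_shift_ms: int = 10) -> int:
        # one chunk covers decode_chunk_size subsampled frames, i.e.
        # decode_chunk_size * subsampling_factor feature frames
        return decode_chunk_size * subsampling_factor * frame_shift_ms

    print(chunk_ms(1))   # 40  -> new default: ~40 ms of audio per chunk
    print(chunk_ms(16))  # 640 -> old default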
@@ -334,7 +316,7 @@ def decode_one_chunk(
     # we add 2 here because we will cut off one frame on each side of
     # the encoder_embed output, as those frames see invalid paddings; so
     # we need 2 extra frames.
-    tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
+    tail_length = 7 + 2 * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
         features = torch.cat(
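The hunk cuts off at the torch.cat call; a self-contained sketch of that padding step (the LOG_EPS fill value and the helper name are assumptions, echoing the log-Fbank floor used elsewhere in icefall):

    import math

    import torch

    LOG_EPS = math.log(1e-10)  # assumed pad value for log-Fbank features

    def pad_to_tail(features: torch.Tensor, feature_lens: torch.Tensor,
                    subsampling_factor: int = 4):
        # pad (N, T, F) features so the last chunk survives subsampling
        tail_length = 7 + 2 * subsampling_factor
        if features.size(1) < tail_length:
            pad = tail_length - features.size(1)
            feature_lens = feature_lens + pad
            filler = torch.full((features.size(0), pad, features.size(2)), LOG_EPS)
            features = torch.cat([features, filler], dim=1)
        return features, feature_lens

    x, lens = pad_to_tail(torch.zeros(2, 5, 80), torch.tensor([5, 5]))
    print(x.shape, lens)  # torch.Size([2, 15, 80]) tensor([15, 15])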
@@ -357,13 +339,10 @@ def decode_one_chunk(
     ]
     processed_lens = torch.tensor(processed_lens, device=device)
 
-    encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
+    encoder_out, encoder_out_lens, states = model.encoder.infer(
         x=features,
         x_lens=feature_lens,
         states=states,
-        left_context=params.left_context,
-        right_context=params.right_context,
-        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
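Why left_context, right_context, and processed_lens drop out of the call: an LSTM carries its entire left history in its recurrent state and has no look-ahead. A schematic loop with a plain torch.nn.LSTM standing in for the recipe's encoder (all sizes are illustrative):

    import torch

    lstm = torch.nn.LSTM(input_size=80, hidden_size=64, num_layers=2,
                         batch_first=True)
    h = torch.zeros(2, 1, 64)  # (num_layers, batch, hidden_size)
    c = torch.zeros(2, 1, 64)
    for _ in range(3):                     # one iteration per decoded chunk
        chunk = torch.randn(1, 4, 80)      # (batch, chunk_frames, num_feats)
        out, (h, c) = lstm(chunk, (h, c))  # states thread between chunks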
@@ -442,8 +421,8 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_states(
+        device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
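A plausible reading of the new get_init_states(device=device): zeroed (h, c) tensors that seed each fresh DecodeStream. The layer and hidden sizes below are illustrative, not taken from the recipe:

    import torch

    def get_init_states(num_layers: int = 12, hidden_size: int = 512,
                        device: torch.device = torch.device("cpu")):
        # one zeroed (h, c) pair; streaming decode clones it per stream
        h = torch.zeros(num_layers, 1, hidden_size, device=device)
        c = torch.zeros(num_layers, 1, hidden_size, device=device)
        return (h, c)

    h0, c0 = get_init_states()
    print(h0.shape)  # torch.Size([12, 1, 512])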
@@ -584,8 +563,6 @@ def main():
 
     # for streaming
     params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
-    params.suffix += f"-left-context-{params.left_context}"
-    params.suffix += f"-right-context-{params.right_context}"
 
     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
@@ -609,8 +586,6 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
-    # Decoding in streaming requires causal convolution
-    params.causal_convolution = True
 
     logging.info(params)