modify streaming decoding files

yaozengwei 2022-07-17 16:09:24 +08:00
parent 822cc78a9c
commit 5c669b7716
2 changed files with 16 additions and 41 deletions

View File

@@ -71,9 +71,9 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        # add 2 here since we will drop the first and last frames after
+        # the convolutional subsampling module
+        self.pad_length = 2 * params.subsampling_factor + 3
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
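
Note: the padding for the final chunk no longer depends on a right context. As a quick check of the new formula, assuming the subsampling factor of 4 that these recipes typically use (an assumption, not shown in this hunk):

    subsampling_factor = 4  # assumed; not part of this diff
    # old formula at right_context = 0: (0 + 2) * 4 + 3 = 11
    # new formula, fixed since the LSTM encoder sees no right context:
    pad_length = 2 * subsampling_factor + 3  # = 11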

View File

@@ -17,15 +17,13 @@
 """
 Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
-        --epoch 28 \
-        --avg 15 \
-        --left-context 32 \
-        --decode-chunk-size 8 \
-        --right-context 0 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
-        --decoding-method greedy_search \
-        --num-decode-streams 1000
+./lstm_transducer_stateless/streaming_decode.py \
+        --epoch 28 \
+        --avg 15 \
+        --exp-dir ./lstm_transducer_stateless/exp \
+        --decoding-method greedy_search \
+        --decode-chunk-size 1 \
+        --num-decode-streams 1000
 """

 import argparse
@@ -44,7 +42,7 @@ from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
@@ -155,24 +153,10 @@ def get_parser():
     parser.add_argument(
         "--decode-chunk-size",
         type=int,
-        default=16,
+        default=1,
         help="The chunk size for decoding (in frames after subsampling)",
     )

-    parser.add_argument(
-        "--left-context",
-        type=int,
-        default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",
-    )
-
-    parser.add_argument(
-        "--right-context",
-        type=int,
-        default=0,
-        help="right context can be seen during decoding (in frames after subsampling)",
-    )
-
     parser.add_argument(
         "--num-decode-streams",
         type=int,
@@ -180,8 +164,6 @@
         help="The number of streams that can be decoded in parallel.",
     )

-    add_model_arguments(parser)
-
     return parser
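
Note: an LSTM runs strictly left to right, so the conformer-style --left-context and --right-context options have no meaning here and only the chunk size remains. A minimal sketch of what the new default implies, assuming the usual 10 ms frame shift and subsampling factor of 4 (both assumptions, not read from this diff):

    args = get_parser().parse_args([])
    assert args.decode_chunk_size == 1
    # one subsampled frame per chunk = 4 input frames = 40 ms of audio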
@@ -334,7 +316,7 @@ def decode_one_chunk(
     # We add 2 here because we will cut off one frame on each side of the
     # encoder_embed output, as those frames see invalid paddings, so we
     # need 2 extra frames.
-    tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
+    tail_length = 7 + 2 * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
         features = torch.cat(
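
Note: with a subsampling factor of 4 (assumed), the tail check becomes a fixed 7 + 2 * 4 = 15 frames. A sketch of the padding branch whose torch.cat is truncated above, using LOG_EPS as the fill value on the assumption that this script defines it like the other icefall streaming scripts (e.g. math.log(1e-10)):

    if features.size(1) < tail_length:
        pad = tail_length - features.size(1)
        feature_lens += pad
        features = torch.cat(
            [
                features,
                torch.full(
                    (features.size(0), pad, features.size(2)),
                    LOG_EPS,  # assumed log-epsilon fill value
                    device=features.device,
                    dtype=features.dtype,
                ),
            ],
            dim=1,
        )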
@@ -357,13 +339,10 @@
     ]
     processed_lens = torch.tensor(processed_lens, device=device)

-    encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
+    encoder_out, encoder_out_lens, states = model.encoder.infer(
         x=features,
         x_lens=feature_lens,
         states=states,
-        left_context=params.left_context,
-        right_context=params.right_context,
         processed_lens=processed_lens,
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
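
Note: the recurrent states now carry all of the history, so the context arguments simply disappear from the call. A sketch of the surrounding bookkeeping, under the assumption that each stream holds a per-stream (hidden, cell) pair that gets stacked along the batch dimension; the gather/scatter pattern here is illustrative, not copied from the file:

    hs = torch.stack([s.states[0] for s in decode_streams], dim=1)
    cs = torch.stack([s.states[1] for s in decode_streams], dim=1)
    encoder_out, encoder_out_lens, (hs, cs) = model.encoder.infer(
        x=features,
        x_lens=feature_lens,
        states=(hs, cs),
        processed_lens=processed_lens,
    )
    for i, s in enumerate(decode_streams):
        s.states = (hs[:, i], cs[:, i])  # hand back each stream's new state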
@@ -442,8 +421,8 @@
     decode_results = []
     # Contains the decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
+    initial_states = model.encoder.get_init_states(
+        device=device
     )
     for num, cut in enumerate(cuts):
         # Each utterance has a DecodeStream.
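
Note: with no left context to size a cache by, only a device is needed to build the initial state. A hedged sketch of what such a method plausibly looks like for an LSTM encoder (the attribute names are assumptions, not taken from the file):

    def get_init_states(self, device: torch.device):
        # zero (hidden, cell) tensors for a single stream; num_encoder_layers,
        # d_model and rnn_hidden_size are assumed attribute names
        hidden = torch.zeros(self.num_encoder_layers, self.d_model, device=device)
        cell = torch.zeros(self.num_encoder_layers, self.rnn_hidden_size, device=device)
        return (hidden, cell)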
@@ -584,8 +563,6 @@
     # for streaming
     params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
-    params.suffix += f"-left-context-{params.left_context}"
-    params.suffix += f"-right-context-{params.right_context}"

     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
@@ -609,8 +586,6 @@
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

-    # Decoding in streaming requires causal convolution
-    params.causal_convolution = True

     logging.info(params)