modify streaming decoding files
This commit is contained in:
parent
822cc78a9c
commit
5c669b7716
@ -71,9 +71,9 @@ class DecodeStream(object):
|
|||||||
# encoder.streaming_forward
|
# encoder.streaming_forward
|
||||||
self.done_frames: int = 0
|
self.done_frames: int = 0
|
||||||
|
|
||||||
self.pad_length = (
|
# add 2 here since we will drop the first and last frames after
|
||||||
params.right_context + 2
|
# the convolutional subsampling module
|
||||||
) * params.subsampling_factor + 3
|
self.pad_length = 2 * params.subsampling_factor + 3
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
self.hyp = [params.blank_id] * params.context_size
|
self.hyp = [params.blank_id] * params.context_size
|
||||||
|
|||||||
@ -17,15 +17,13 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless2/streaming_decode.py \
|
./lstm_transducer_stateless/streaming_decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--left-context 32 \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--decode-chunk-size 8 \
|
--decoding_method greedy_search \
|
||||||
--right-context 0 \
|
--decode-chunk-size 1 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--num-decode-streams 1000
|
||||||
--decoding_method greedy_search \
|
|
||||||
--num-decode-streams 1000
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -44,7 +42,7 @@ from decode_stream import DecodeStream
|
|||||||
from kaldifeat import Fbank, FbankOptions
|
from kaldifeat import Fbank, FbankOptions
|
||||||
from lhotse import CutSet
|
from lhotse import CutSet
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -155,24 +153,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=1,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--left-context",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--right-context",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="right context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-decode-streams",
|
"--num-decode-streams",
|
||||||
type=int,
|
type=int,
|
||||||
@ -180,8 +164,6 @@ def get_parser():
|
|||||||
help="The number of streams that can be decoded parallel.",
|
help="The number of streams that can be decoded parallel.",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -334,7 +316,7 @@ def decode_one_chunk(
|
|||||||
# we add 2 here because we will cut off one frame on each side of
|
# we add 2 here because we will cut off one frame on each side of
|
||||||
# encoder_embed output as they see invalid paddings. So we need 2 extra
|
# encoder_embed output as they see invalid paddings. So we need 2 extra
|
||||||
# frames.
|
# frames.
|
||||||
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
|
tail_length = 7 + 2 * params.subsampling_factor
|
||||||
if features.size(1) < tail_length:
|
if features.size(1) < tail_length:
|
||||||
feature_lens += tail_length - features.size(1)
|
feature_lens += tail_length - features.size(1)
|
||||||
features = torch.cat(
|
features = torch.cat(
|
||||||
@ -357,13 +339,10 @@ def decode_one_chunk(
|
|||||||
]
|
]
|
||||||
processed_lens = torch.tensor(processed_lens, device=device)
|
processed_lens = torch.tensor(processed_lens, device=device)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, states = model.encoder.infer(
|
||||||
x=features,
|
x=features,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
states=states,
|
states=states,
|
||||||
left_context=params.left_context,
|
|
||||||
right_context=params.right_context,
|
|
||||||
processed_lens=processed_lens,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
@ -442,8 +421,8 @@ def decode_dataset(
|
|||||||
decode_results = []
|
decode_results = []
|
||||||
# Contain decode streams currently running.
|
# Contain decode streams currently running.
|
||||||
decode_streams = []
|
decode_streams = []
|
||||||
initial_states = model.encoder.get_init_state(
|
initial_states = model.encoder.get_init_states(
|
||||||
params.left_context, device=device
|
device=device
|
||||||
)
|
)
|
||||||
for num, cut in enumerate(cuts):
|
for num, cut in enumerate(cuts):
|
||||||
# each utterance has a DecodeStream.
|
# each utterance has a DecodeStream.
|
||||||
@ -584,8 +563,6 @@ def main():
|
|||||||
|
|
||||||
# for streaming
|
# for streaming
|
||||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
|
||||||
params.suffix += f"-right-context-{params.right_context}"
|
|
||||||
|
|
||||||
# for fast_beam_search
|
# for fast_beam_search
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
@ -609,8 +586,6 @@ def main():
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
# Decoding in streaming requires causal convolution
|
|
||||||
params.causal_convolution = True
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user