update pretrain to support streaming model

This commit is contained in:
pkufool 2022-06-25 18:24:36 +08:00
parent 8d37175ffb
commit af80a463d3
3 changed files with 131 additions and 12 deletions

View File

@ -77,7 +77,9 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -177,6 +179,29 @@ def get_parser():
--method is greedy_search. --method is greedy_search.
""", """,
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser return parser
@ -222,6 +247,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -268,9 +298,18 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder( if params.simulate_streaming:
x=features, x_lens=feature_lengths encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
) x=features,
x_lens=feature_lengths,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []

View File

@ -77,7 +77,9 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -178,6 +180,30 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser return parser
@ -222,6 +248,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -268,9 +299,18 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder( if params.simulate_streaming:
x=features, x_lens=feature_lengths encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
) x=features,
x_lens=feature_lengths,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []

View File

@ -77,7 +77,9 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -178,6 +180,30 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser return parser
@ -222,6 +248,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -268,9 +299,18 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder( if params.simulate_streaming:
x=features, x_lens=feature_lengths encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
) x=features,
x_lens=feature_lengths,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []