Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
update pretrain to support streaming model

commit af80a463d3 (parent 8d37175ffb)
@@ -77,7 +77,9 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -177,6 +179,29 @@ def get_parser():
         --method is greedy_search.
         """,
     )
 
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
     return parser
 
 
@@ -222,6 +247,11 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -268,9 +298,18 @@ def main():
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    if params.simulate_streaming:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=features,
+            x_lens=feature_lengths,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=features, x_lens=feature_lengths
+        )
 
     num_waves = encoder_out.size(0)
     hyps = []
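For readers following the decode path, below is a minimal, self-contained sketch of the dispatch this commit adds. DummyEncoder and run_encoder are hypothetical stand-ins for illustration only, not part of icefall; the only details taken from the diff are the call shapes (streaming_forward takes chunk_size/left_context and returns a third value, which the script ignores; the plain forward returns two values) and the requirement that simulated streaming needs params.causal_convolution.

import torch
import torch.nn as nn
from types import SimpleNamespace


class DummyEncoder(nn.Module):
    """Hypothetical stand-in for the recipe's encoder, used only in this sketch."""

    def forward(self, x, x_lens):
        # Full-context (non-streaming) forward pass.
        return x, x_lens

    def streaming_forward(self, x, x_lens, chunk_size, left_context, simulate_streaming):
        # Simulated chunk-wise decoding; the third return value is ignored by
        # the script (it is bound to `_` in the diff above).
        return x, x_lens, None


def run_encoder(params, model, features, feature_lengths):
    """Mirror the branch added in this commit."""
    if params.simulate_streaming:
        # Simulated streaming only makes sense for a causal (streaming) model.
        assert params.causal_convolution, "Decoding in streaming requires causal convolution"
        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
            x=features,
            x_lens=feature_lengths,
            chunk_size=params.decode_chunk_size,
            left_context=params.left_context,
            simulate_streaming=True,
        )
    else:
        encoder_out, encoder_out_lens = model.encoder(
            x=features, x_lens=feature_lengths
        )
    return encoder_out, encoder_out_lens


if __name__ == "__main__":
    params = SimpleNamespace(
        simulate_streaming=True,
        causal_convolution=True,
        decode_chunk_size=16,
        left_context=64,
    )
    model = SimpleNamespace(encoder=DummyEncoder())
    features = torch.randn(2, 100, 80)        # (batch, frames, feature dim)
    feature_lengths = torch.tensor([100, 80])
    out, out_lens = run_encoder(params, model, features, feature_lengths)
    print(out.shape, out_lens)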
The same change is applied to a second, nearly identical copy of the script in this commit (the hunk offsets differ by one line):

@@ -77,7 +77,9 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -178,6 +180,30 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
     return parser
 
 
@@ -222,6 +248,11 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -268,9 +299,18 @@ def main():
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    if params.simulate_streaming:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=features,
+            x_lens=feature_lengths,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=features, x_lens=feature_lengths
+        )
 
     num_waves = encoder_out.size(0)
     hyps = []
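The --simulate-streaming option is typed with str2bool from icefall.utils, so it can be toggled from the command line (presumably with spellings such as true/false or 1/0). The sketch below shows how the three new options wire into argparse in isolation; the local str2bool is an assumed re-implementation for illustration, not the icefall helper itself, and the function name add_streaming_decode_args is hypothetical.

import argparse


def str2bool(v):
    """Assumed behaviour of icefall.utils.str2bool: turn common true/false
    spellings into a Python bool so argparse can use it as a `type`."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


def add_streaming_decode_args(parser):
    # The same three options added in this commit.
    parser.add_argument(
        "--simulate-streaming",
        type=str2bool,
        default=False,
        help="Whether to simulate streaming in decoding; a good way to test a streaming model.",
    )
    parser.add_argument(
        "--decode-chunk-size",
        type=int,
        default=16,
        help="The chunk size for decoding (in frames after subsampling).",
    )
    parser.add_argument(
        "--left-context",
        type=int,
        default=64,
        help="Left context visible during decoding (in frames after subsampling).",
    )


if __name__ == "__main__":
    # e.g.: python sketch.py --simulate-streaming true --decode-chunk-size 8 --left-context 32
    parser = argparse.ArgumentParser()
    add_streaming_decode_args(parser)
    args = parser.parse_args()
    print(args.simulate_streaming, args.decode_chunk_size, args.left_context)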