mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Remove simulate streaming from stateless8 (#985)
This commit is contained in:
parent
d337398d29
commit
c90f57afdb
@ -301,29 +301,6 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--simulate-streaming",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
|
||||||
test a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--left-context",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -378,21 +355,6 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
feature_lens += params.left_context
|
|
||||||
feature = torch.nn.functional.pad(
|
|
||||||
feature,
|
|
||||||
pad=(0, 0, 0, params.left_context),
|
|
||||||
value=LOG_EPS,
|
|
||||||
)
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
|
||||||
x=feature,
|
|
||||||
x_lens=feature_lens,
|
|
||||||
chunk_size=params.decode_chunk_size,
|
|
||||||
left_context=params.left_context,
|
|
||||||
simulate_streaming=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
@ -651,10 +613,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -690,11 +648,6 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
assert (
|
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user