Remove simulate streaming from stateless7 (#983)

* Remove simulate streaming from stateless7
This commit is contained in:
Yifan Yang 2023-04-03 14:55:45 +08:00 committed by GitHub
parent 180c7c2b7a
commit 46bf6df62f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -343,29 +343,6 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
@ -474,22 +451,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
@ -782,10 +744,6 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
@ -834,11 +792,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")