minor fix in decode.py, about args

This commit is contained in:
yaozengwei 2023-04-13 17:20:25 +08:00
parent d27e61170b
commit 87d9491fba

View File

@ -301,27 +301,18 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
help="The chunk size for decoding (in frames after subsampling), at 50Hz frame rate",
)
parser.add_argument(
"--left-context",
"--decode-left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
help="left context can be seen during decoding (in frames after subsampling), at 50Hz frame rate",
)
add_model_arguments(parser)
@ -378,28 +369,19 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
# this seems to cause insertions at the end of the utterance if used with zipformer.
#feature_lens += params.left_context
#feature = torch.nn.functional.pad(
# feature,
# pad=(0, 0, 0, params.left_context),
# value=LOG_EPS,
#)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = params.decode_chunk_size
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
if params.simulate_streaming:
# the chunk size and left context are now stored with the model.
# TODO: implement streaming_forward.
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
#chunk_size=params.decode_chunk_size,
#left_context=params.left_context,
#simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
@ -669,11 +651,12 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# TODO: may still want to add something here? for now I am just
# moving the decoding directories around after decoding.
#if params.simulate_streaming:
#params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
#params.suffix += f"-left-context-{params.left_context}"
if params.causal:
# 'chunk_size' and 'left_context_frames' are used in function 'get_encoder_model' in train.py
params.chunk_size = str(params.decode_chunk_size)
params.left_context_frames = str(params.decode_left_context)
params.suffix += f"-decode-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-decode-left-context-{params.decode_left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
@ -712,11 +695,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")