Fix comments

This commit is contained in:
pkufool 2022-07-23 22:42:18 +08:00
parent 353863a55c
commit 3c4c615e1f
6 changed files with 30 additions and 26 deletions

View File

@ -55,7 +55,7 @@ def greedy_search(
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
for t in range(T):
@ -93,7 +93,7 @@ def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
beam: int = 4,
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -105,7 +105,7 @@ def modified_beam_search(
the encoder model.
streams:
A list of stream objects.
beam:
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
@ -171,7 +171,9 @@ def modified_beam_search(
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

View File

@ -17,13 +17,13 @@
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
./pruned_transducer_stateless/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-size 8 \
--left-context 32 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless/exp \
--decoding_method greedy_search \
--num-decode-streams 1000
"""
@ -125,7 +125,7 @@ def get_parser():
)
parser.add_argument(
"--beam-size",
"--num-active-paths",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
@ -288,7 +288,7 @@ def decode_one_chunk(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
beam=params.beam_size,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
@ -417,7 +417,7 @@ def decode_dataset(
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}"
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"

View File

@ -55,7 +55,7 @@ def greedy_search(
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
@ -96,7 +96,7 @@ def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
beam: int = 4,
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -108,7 +108,7 @@ def modified_beam_search(
the encoder model.
streams:
A list of stream objects.
beam:
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
@ -177,7 +177,9 @@ def modified_beam_search(
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

View File

@ -125,7 +125,7 @@ def get_parser():
)
parser.add_argument(
"--beam-size",
"--num-active-paths",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
@ -290,7 +290,7 @@ def decode_one_chunk(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
beam=params.beam_size,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
@ -420,7 +420,7 @@ def decode_dataset(
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}"
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"

View File

@ -17,13 +17,13 @@
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
./pruned_transducer_stateless3/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless3/exp \
--decoding_method greedy_search \
--num-decode-streams 1000
"""
@ -126,7 +126,7 @@ def get_parser():
)
parser.add_argument(
"--beam-size",
"--num-active-paths",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
@ -291,7 +291,7 @@ def decode_one_chunk(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
beam=params.beam_size,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
@ -421,7 +421,7 @@ def decode_dataset(
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}"
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"

View File

@ -17,13 +17,13 @@
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
./pruned_transducer_stateless4/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless4/exp \
--decoding_method greedy_search \
--num-decode-streams 200
"""
@ -138,7 +138,7 @@ def get_parser():
)
parser.add_argument(
"--beam-size",
"--num-active-paths",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
@ -303,7 +303,7 @@ def decode_one_chunk(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
beam=params.beam_size,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
@ -433,7 +433,7 @@ def decode_dataset(
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}"
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"