Fix comments

This commit is contained in:
pkufool 2022-07-23 22:42:18 +08:00
parent 353863a55c
commit 3c4c615e1f
6 changed files with 30 additions and 26 deletions

View File

@ -55,7 +55,7 @@ def greedy_search(
device=device, device=device,
dtype=torch.int64, dtype=torch.int64,
) )
# decoder_out is of shape (N, decoder_out_dim) # decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
for t in range(T): for t in range(T):
@ -93,7 +93,7 @@ def modified_beam_search(
model: nn.Module, model: nn.Module,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
streams: List[DecodeStream], streams: List[DecodeStream],
beam: int = 4, num_active_paths: int = 4,
) -> None: ) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -105,7 +105,7 @@ def modified_beam_search(
the encoder model. the encoder model.
streams: streams:
A list of stream objects. A list of stream objects.
beam: num_active_paths:
Number of active paths during the beam search. Number of active paths during the beam search.
""" """
assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.ndim == 3, encoder_out.shape
@ -171,7 +171,9 @@ def modified_beam_search(
) )
for i in range(batch_size): for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")

View File

@ -17,13 +17,13 @@
""" """
Usage: Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--decode-chunk-size 8 \ --decode-chunk-size 8 \
--left-context 32 \ --left-context 32 \
--right-context 0 \ --right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 1000 --num-decode-streams 1000
""" """
@ -125,7 +125,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--beam-size", "--num-active-paths",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An interger indicating how many candidates we will keep for each
@ -288,7 +288,7 @@ def decode_one_chunk(
model=model, model=model,
streams=decode_streams, streams=decode_streams,
encoder_out=encoder_out, encoder_out=encoder_out,
beam=params.beam_size, num_active_paths=params.num_active_paths,
) )
else: else:
raise ValueError( raise ValueError(
@ -417,7 +417,7 @@ def decode_dataset(
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
) )
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"

View File

@ -55,7 +55,7 @@ def greedy_search(
device=device, device=device,
dtype=torch.int64, dtype=torch.int64,
) )
# decoder_out is of shape (N, decoder_out_dim) # decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out) decoder_out = model.joiner.decoder_proj(decoder_out)
@ -96,7 +96,7 @@ def modified_beam_search(
model: nn.Module, model: nn.Module,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
streams: List[DecodeStream], streams: List[DecodeStream],
beam: int = 4, num_active_paths: int = 4,
) -> None: ) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -108,7 +108,7 @@ def modified_beam_search(
the encoder model. the encoder model.
streams: streams:
A list of stream objects. A list of stream objects.
beam: num_active_paths:
Number of active paths during the beam search. Number of active paths during the beam search.
""" """
assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.ndim == 3, encoder_out.shape
@ -177,7 +177,9 @@ def modified_beam_search(
) )
for i in range(batch_size): for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")

View File

@ -125,7 +125,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--beam-size", "--num_active_paths",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An interger indicating how many candidates we will keep for each
@ -290,7 +290,7 @@ def decode_one_chunk(
model=model, model=model,
streams=decode_streams, streams=decode_streams,
encoder_out=encoder_out, encoder_out=encoder_out,
beam=params.beam_size, num_active_paths=params.num_active_paths,
) )
else: else:
raise ValueError( raise ValueError(
@ -420,7 +420,7 @@ def decode_dataset(
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
) )
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"

View File

@ -17,13 +17,13 @@
""" """
Usage: Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless3/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--left-context 32 \ --left-context 32 \
--decode-chunk-size 8 \ --decode-chunk-size 8 \
--right-context 0 \ --right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 1000 --num-decode-streams 1000
""" """
@ -126,7 +126,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--beam-size", "--num_active_paths",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An interger indicating how many candidates we will keep for each
@ -291,7 +291,7 @@ def decode_one_chunk(
model=model, model=model,
streams=decode_streams, streams=decode_streams,
encoder_out=encoder_out, encoder_out=encoder_out,
beam=params.beam_size, num_active_paths=params.num_active_paths,
) )
else: else:
raise ValueError( raise ValueError(
@ -421,7 +421,7 @@ def decode_dataset(
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
) )
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"

View File

@ -17,13 +17,13 @@
""" """
Usage: Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless4/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--left-context 32 \ --left-context 32 \
--decode-chunk-size 8 \ --decode-chunk-size 8 \
--right-context 0 \ --right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 200 --num-decode-streams 200
""" """
@ -138,7 +138,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--beam-size", "--num_active_paths",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An interger indicating how many candidates we will keep for each
@ -303,7 +303,7 @@ def decode_one_chunk(
model=model, model=model,
streams=decode_streams, streams=decode_streams,
encoder_out=encoder_out, encoder_out=encoder_out,
beam=params.beam_size, num_active_paths=params.num_active_paths,
) )
else: else:
raise ValueError( raise ValueError(
@ -433,7 +433,7 @@ def decode_dataset(
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
) )
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
key = f"beam_size_{params.beam_size}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"