mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Fix comments
This commit is contained in:
parent
353863a55c
commit
3c4c615e1f
@ -55,7 +55,7 @@ def greedy_search(
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
# decoder_out is of shape (N, decoder_out_dim)
|
||||
# decoder_out is of shape (N, 1, decoder_out_dim)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
for t in range(T):
|
||||
@ -93,7 +93,7 @@ def modified_beam_search(
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
streams: List[DecodeStream],
|
||||
beam: int = 4,
|
||||
num_active_paths: int = 4,
|
||||
) -> None:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
@ -105,7 +105,7 @@ def modified_beam_search(
|
||||
the encoder model.
|
||||
streams:
|
||||
A list of stream objects.
|
||||
beam:
|
||||
num_active_paths:
|
||||
Number of active paths during the beam search.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
@ -171,7 +171,9 @@ def modified_beam_search(
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
|
||||
num_active_paths
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -17,13 +17,13 @@
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/streaming_decode.py \
|
||||
./pruned_transducer_stateless/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--decode-chunk-size 8 \
|
||||
--left-context 32 \
|
||||
--right-context 0 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 1000
|
||||
"""
|
||||
@ -125,7 +125,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
"--num-active-paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
@ -288,7 +288,7 @@ def decode_one_chunk(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -417,7 +417,7 @@ def decode_dataset(
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
@ -55,7 +55,7 @@ def greedy_search(
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
# decoder_out is of shape (N, decoder_out_dim)
|
||||
# decoder_out is of shape (N, 1, decoder_out_dim)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
@ -96,7 +96,7 @@ def modified_beam_search(
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
streams: List[DecodeStream],
|
||||
beam: int = 4,
|
||||
num_active_paths: int = 4,
|
||||
) -> None:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
@ -108,7 +108,7 @@ def modified_beam_search(
|
||||
the encoder model.
|
||||
streams:
|
||||
A list of stream objects.
|
||||
beam:
|
||||
num_active_paths:
|
||||
Number of active paths during the beam search.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
@ -177,7 +177,9 @@ def modified_beam_search(
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
|
||||
num_active_paths
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -125,7 +125,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
@ -290,7 +290,7 @@ def decode_one_chunk(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -420,7 +420,7 @@ def decode_dataset(
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
@ -17,13 +17,13 @@
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/streaming_decode.py \
|
||||
./pruned_transducer_stateless3/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--left-context 32 \
|
||||
--decode-chunk-size 8 \
|
||||
--right-context 0 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 1000
|
||||
"""
|
||||
@ -126,7 +126,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
@ -291,7 +291,7 @@ def decode_one_chunk(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -421,7 +421,7 @@ def decode_dataset(
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
@ -17,13 +17,13 @@
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/streaming_decode.py \
|
||||
./pruned_transducer_stateless4/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--left-context 32 \
|
||||
--decode-chunk-size 8 \
|
||||
--right-context 0 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 200
|
||||
"""
|
||||
@ -138,7 +138,7 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
@ -303,7 +303,7 @@ def decode_one_chunk(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -433,7 +433,7 @@ def decode_dataset(
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user