Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Fix comments

commit 3c4c615e1f
parent 353863a55c
@@ -55,7 +55,7 @@ def greedy_search(
         device=device,
         dtype=torch.int64,
     )
-    # decoder_out is of shape (N, decoder_out_dim)
+    # decoder_out is of shape (N, 1, decoder_out_dim)
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
     for t in range(T):
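
Note: the corrected comment follows from the stateless decoder used in these recipes, where an input of context_size previous tokens collapses to a single output frame per stream when need_pad=False. Below is a minimal, self-contained sketch with a toy embedding and Conv1d standing in for the real decoder; all dimensions and module choices are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn

# Toy stand-in for the stateless decoder: embedding followed by a Conv1d with
# kernel_size=context_size and no padding, so an input of length context_size
# yields exactly one output frame.
N, context_size, decoder_out_dim, vocab_size = 8, 2, 512, 500
embedding = nn.Embedding(vocab_size, decoder_out_dim)
conv = nn.Conv1d(decoder_out_dim, decoder_out_dim, kernel_size=context_size)

decoder_input = torch.randint(0, vocab_size, (N, context_size))
x = embedding(decoder_input).permute(0, 2, 1)  # (N, decoder_out_dim, context_size)
decoder_out = conv(x).permute(0, 2, 1)         # (N, 1, decoder_out_dim)
assert decoder_out.shape == (N, 1, decoder_out_dim)
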
@@ -93,7 +93,7 @@ def modified_beam_search(
     model: nn.Module,
     encoder_out: torch.Tensor,
     streams: List[DecodeStream],
-    beam: int = 4,
+    num_active_paths: int = 4,
 ) -> None:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
 
@@ -105,7 +105,7 @@ def modified_beam_search(
         the encoder model.
       streams:
         A list of stream objects.
-      beam:
+      num_active_paths:
         Number of active paths during the beam search.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
@@ -171,7 +171,9 @@ def modified_beam_search(
         )
 
     for i in range(batch_size):
-        topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+        topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
+            num_active_paths
+        )
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
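
Note: the topk call in the last hunk keeps the num_active_paths best hypotheses for each stream. As a rough, self-contained analogy (plain torch.topk on a dense row of made-up scores; not the repository's ragged implementation):

import torch

num_active_paths = 4
# Made-up log-probabilities for the expanded hypotheses of one stream.
log_probs = torch.tensor([-1.2, -0.3, -2.5, -0.9, -3.1, -0.7])

# Keep the num_active_paths highest-scoring hypotheses, as the renamed
# argument now controls.
topk_log_probs, topk_indexes = log_probs.topk(num_active_paths)
print(topk_log_probs)  # tensor([-0.3000, -0.7000, -0.9000, -1.2000])
print(topk_indexes)    # tensor([1, 5, 3, 0])
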
@@ -17,13 +17,13 @@
 
 """
 Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
+./pruned_transducer_stateless/streaming_decode.py \
         --epoch 28 \
         --avg 15 \
         --decode-chunk-size 8 \
         --left-context 32 \
         --right-context 0 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./pruned_transducer_stateless/exp \
         --decoding_method greedy_search \
         --num-decode-streams 1000
 """
@@ -125,7 +125,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--beam-size",
+        "--num-active-paths",
         type=int,
         default=4,
         help="""An interger indicating how many candidates we will keep for each
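
Note: a self-contained sketch of the renamed flag, assuming a plain argparse parser (the help text is abridged here). argparse maps --num-active-paths to args.num_active_paths, which is what the decode_one_chunk hunk below reads via params.num_active_paths.

import argparse

# Hypothetical minimal parser; in the repository the flag is added in get_parser().
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-active-paths",
    type=int,
    default=4,
    help="Number of candidate paths kept per stream in modified_beam_search "
    "(abridged from the repository's help string).",
)

args = parser.parse_args(["--num-active-paths", "8"])
print(args.num_active_paths)  # 8; hyphens become underscores in the attribute name
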
@@ -288,7 +288,7 @@ def decode_one_chunk(
             model=model,
             streams=decode_streams,
             encoder_out=encoder_out,
-            beam=params.beam_size,
+            num_active_paths=params.num_active_paths,
         )
     else:
         raise ValueError(
@@ -417,7 +417,7 @@ def decode_dataset(
             f"max_states_{params.max_states}"
         )
     elif params.decoding_method == "modified_beam_search":
-        key = f"beam_size_{params.beam_size}"
+        key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
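
Note: this hunk only changes how result files are labelled for the modified_beam_search case. A minimal illustration, using SimpleNamespace as a stand-in for the recipe's params object (an assumption for the sketch):

from types import SimpleNamespace

# Stand-in for params; the real object is built by the recipe's parser/params setup.
params = SimpleNamespace(decoding_method="modified_beam_search", num_active_paths=4)

if params.decoding_method == "modified_beam_search":
    key = f"num_active_paths_{params.num_active_paths}"

print(key)  # num_active_paths_4
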
@@ -55,7 +55,7 @@ def greedy_search(
         device=device,
         dtype=torch.int64,
     )
-    # decoder_out is of shape (N, decoder_out_dim)
+    # decoder_out is of shape (N, 1, decoder_out_dim)
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
 
@@ -96,7 +96,7 @@ def modified_beam_search(
     model: nn.Module,
     encoder_out: torch.Tensor,
     streams: List[DecodeStream],
-    beam: int = 4,
+    num_active_paths: int = 4,
 ) -> None:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
 
@@ -108,7 +108,7 @@ def modified_beam_search(
         the encoder model.
       streams:
         A list of stream objects.
-      beam:
+      num_active_paths:
         Number of active paths during the beam search.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
@@ -177,7 +177,9 @@ def modified_beam_search(
         )
 
     for i in range(batch_size):
-        topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+        topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
+            num_active_paths
+        )
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
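
Note: this second greedy_search differs from the first only in projecting decoder_out through model.joiner.decoder_proj before use; the (N, 1, ...) shape is preserved. A sketch with a stand-in Linear projection (dimensions are assumptions, not values from the repository):

import torch
import torch.nn as nn

N, decoder_out_dim, joiner_dim = 8, 512, 512  # illustrative sizes only
decoder_proj = nn.Linear(decoder_out_dim, joiner_dim)  # stands in for model.joiner.decoder_proj

decoder_out = torch.randn(N, 1, decoder_out_dim)
decoder_out = decoder_proj(decoder_out)  # shape stays (N, 1, joiner_dim)
assert decoder_out.shape == (N, 1, joiner_dim)
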
@@ -125,7 +125,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--beam-size",
+        "--num_active_paths",
         type=int,
         default=4,
         help="""An interger indicating how many candidates we will keep for each
@@ -290,7 +290,7 @@ def decode_one_chunk(
             model=model,
             streams=decode_streams,
             encoder_out=encoder_out,
-            beam=params.beam_size,
+            num_active_paths=params.num_active_paths,
         )
     else:
         raise ValueError(
@@ -420,7 +420,7 @@ def decode_dataset(
             f"max_states_{params.max_states}"
         )
     elif params.decoding_method == "modified_beam_search":
-        key = f"beam_size_{params.beam_size}"
+        key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
@@ -17,13 +17,13 @@
 
 """
 Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
+./pruned_transducer_stateless3/streaming_decode.py \
         --epoch 28 \
         --avg 15 \
         --left-context 32 \
         --decode-chunk-size 8 \
         --right-context 0 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./pruned_transducer_stateless3/exp \
         --decoding_method greedy_search \
         --num-decode-streams 1000
 """
@@ -126,7 +126,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--beam-size",
+        "--num_active_paths",
         type=int,
         default=4,
         help="""An interger indicating how many candidates we will keep for each
@@ -291,7 +291,7 @@ def decode_one_chunk(
             model=model,
             streams=decode_streams,
             encoder_out=encoder_out,
-            beam=params.beam_size,
+            num_active_paths=params.num_active_paths,
         )
     else:
         raise ValueError(
@@ -421,7 +421,7 @@ def decode_dataset(
             f"max_states_{params.max_states}"
         )
     elif params.decoding_method == "modified_beam_search":
-        key = f"beam_size_{params.beam_size}"
+        key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
@@ -17,13 +17,13 @@
 
 """
 Usage:
-./pruned_transducer_stateless2/streaming_decode.py \
+./pruned_transducer_stateless4/streaming_decode.py \
         --epoch 28 \
         --avg 15 \
         --left-context 32 \
         --decode-chunk-size 8 \
         --right-context 0 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./pruned_transducer_stateless4/exp \
         --decoding_method greedy_search \
         --num-decode-streams 200
 """
@@ -138,7 +138,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--beam-size",
+        "--num_active_paths",
         type=int,
         default=4,
         help="""An interger indicating how many candidates we will keep for each
@@ -303,7 +303,7 @@ def decode_one_chunk(
             model=model,
             streams=decode_streams,
             encoder_out=encoder_out,
-            beam=params.beam_size,
+            num_active_paths=params.num_active_paths,
         )
     else:
         raise ValueError(
@@ -433,7 +433,7 @@ def decode_dataset(
             f"max_states_{params.max_states}"
         )
     elif params.decoding_method == "modified_beam_search":
-        key = f"beam_size_{params.beam_size}"
+        key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"