Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 14:44:18 +00:00)
Minor fixes for decoding.
parent 52b3ed2920
commit b0e4e5cf31
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -24,15 +24,24 @@ Usage:
     --exp-dir ./transducer_lstm/exp \
     --max-duration 100 \
     --decoding-method greedy_search

 (2) beam search

 ./transducer_lstm/decode.py \
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_lstm/exp \
     --max-duration 100 \
     --decoding-method beam_search \
-    --beam-size 8
+    --beam-size 4
+
+(3) modified beam search
+
+./transducer_lstm/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer_lstm/exp \
+    --max-duration 100 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
 """
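Note: of the three methods above, greedy_search takes the argmax of the joiner output (with --max-sym-per-frame capping emissions per frame), while beam_search and modified_beam_search keep the --beam-size best hypotheses; in icefall, modified_beam_search additionally allows at most one emitted symbol per frame, which makes it batchable. A minimal greedy-search sketch, simplified to one symbol per frame, with `decoder` and `joiner` as hypothetical stand-ins for the recipe's modules:

    import torch

    def greedy_search_sketch(decoder, joiner, encoder_out,
                             blank_id=0, context_size=2):
        # encoder_out: (T, C) for a single utterance (assumption).
        hyp = [blank_id] * context_size             # blank-padded decoder context
        decoder_out = decoder(torch.tensor([hyp]))  # (1, 1, C)
        for t in range(encoder_out.size(0)):
            cur = encoder_out[t : t + 1].unsqueeze(0)  # (1, 1, C)
            logits = joiner(cur, decoder_out)
            y = int(logits.reshape(-1).argmax())
            if y != blank_id:  # emit the symbol and refresh the decoder state
                hyp.append(y)
                decoder_out = decoder(torch.tensor([hyp[-context_size:]]))
        return hyp[context_size:]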
@@ -71,14 +80,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -112,8 +121,9 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
-        help="Used only when --decoding-method is beam_search",
+        default=4,
+        help="""Used only when --decoding-method is
+        beam_search or modified_beam_search""",
     )

     parser.add_argument(
@@ -123,7 +133,6 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
-
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
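Note: for context, --context-size sets how many previously emitted symbols the stateless decoder conditions on (icefall's stateless decoders are an embedding followed by a 1-D convolution over the last context_size symbols). Illustration only:

    # With context_size = 2 the prediction network sees a tri-gram-like
    # window: only the last two emitted symbols, however long the hypothesis.
    context_size = 2
    hyp = [23, 7, 91, 54]                # emitted symbol ids so far
    decoder_input = hyp[-context_size:]  # [91, 54]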
@@ -348,12 +357,19 @@ def main():
     params = get_params()
     params.update(vars(args))

-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
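Note: the assert and the suffix logic change together here: --decoding-method modified_beam_search must pass the assert, and the substring test makes the "-beam-N" suffix apply to both beam-search variants. How the suffix resolves (illustrative values only):

    def make_suffix(decoding_method, epoch=14, avg=7, beam_size=4,
                    context_size=2, max_sym_per_frame=1):
        suffix = f"epoch-{epoch}-avg-{avg}"
        if "beam_search" in decoding_method:  # beam_search or modified_beam_search
            suffix += f"-beam-{beam_size}"
        else:                                 # greedy_search
            suffix += f"-context-{context_size}"
            suffix += f"-max-sym-per-frame-{max_sym_per_frame}"
        return suffix

    assert make_suffix("modified_beam_search") == "epoch-14-avg-7-beam-4"
    assert make_suffix("greedy_search") == (
        "epoch-14-avg-7-context-2-max-sym-per-frame-1"
    )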
@@ -423,8 +439,5 @@ def main():
     logging.info("Done!")


-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
--- a/egs/librispeech/ASR/transducer_lstm/joiner.py
+++ b/egs/librispeech/ASR/transducer_lstm/joiner.py
@@ -26,7 +26,7 @@ class Joiner(nn.Module):
         self.output_linear = nn.Linear(input_dim, output_dim)

     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
     ) -> torch.Tensor:
         """
         Args:
@@ -51,5 +51,7 @@ class Joiner(nn.Module):
         logit = F.relu(logit)

         output = self.output_linear(logit)
+        if not self.training:
+            output = output.squeeze(2).squeeze(1)

         return output
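Note: two small compatibility changes to the Joiner. `*unused` lets callers that pass extra positional arguments (some icefall recipes forward length tensors to the joiner) reuse this signature unchanged, and the eval-mode squeeze collapses the singleton time and decoder-step axes so search code receives a plain (N, vocab) logit matrix. A shape walk-through, assuming the joiner sums a (N, T, 1, C) encoder projection with a (N, 1, U, C) decoder projection as elsewhere in icefall's transducer recipes:

    import torch

    N, T, U, C, V = 1, 1, 1, 512, 500   # single frame/step during search
    encoder_out = torch.randn(N, T, C)
    decoder_out = torch.randn(N, U, C)

    logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)  # (N, T, U, C)
    output = torch.nn.Linear(C, V)(torch.relu(logit))            # (N, T, U, V)

    # The new eval-mode branch removes the singleton U and T axes:
    assert output.squeeze(2).squeeze(1).shape == (N, V)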
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -634,13 +634,23 @@ def run(rank, world_size, args):

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+    try:
+        num_left = len(train_cuts)
+        num_removed = num_in_total - num_left
+        removed_percent = num_removed / num_in_total * 100
+
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
+        logging.info(f"After removing short and long utterances: {num_left}")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
+    except TypeError as e:
+        # You can ignore this error as previous versions of Lhotse work fine
+        # for the above code. In recent versions of Lhotse, it uses a
+        # lazy filter, producing cut sets that don't have the __len__ method.
+        logging.info(str(e))

     train_dl = librispeech.train_dataloaders(train_cuts)
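Note: the try/except exists because CutSet.filter became lazy in newer Lhotse releases, and a lazily filtered cut set no longer implements __len__. The failure mode in plain Python, with a generator standing in for the lazy CutSet:

    # len() on a generator-backed collection raises TypeError, which is
    # exactly what the training script now catches and logs.
    lazy_cuts = (c for c in range(1000) if c % 2 == 0)

    try:
        num_left = len(lazy_cuts)
    except TypeError as e:
        print(f"Cannot count lazily filtered cuts: {e}")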