Minor fixes for decoding.

This commit is contained in:
Fangjun Kuang 2022-04-28 10:39:08 +08:00
parent 52b3ed2920
commit b0e4e5cf31
3 changed files with 44 additions and 19 deletions

View File

@ -24,15 +24,24 @@ Usage:
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
(2) beam search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 8
--beam-size 4
(3) modified beam search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
"""
@ -71,14 +80,14 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=77,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=55,
default=13,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@ -112,8 +121,9 @@ def get_parser():
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="Used only when --decoding-method is beam_search",
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
)
parser.add_argument(
@ -123,7 +133,6 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
@ -348,12 +357,19 @@ def main():
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -423,8 +439,5 @@ def main():
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -26,7 +26,7 @@ class Joiner(nn.Module):
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
) -> torch.Tensor:
"""
Args:
@ -51,5 +51,7 @@ class Joiner(nn.Module):
logit = F.relu(logit)
output = self.output_linear(logit)
if not self.training:
output = output.squeeze(2).squeeze(1)
return output

View File

@ -634,13 +634,23 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
logging.info(
f"Before removing short and long utterances: {num_in_total}"
)
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts)