Minor fixes for decoding.

Fangjun Kuang 2022-04-28 10:39:08 +08:00
parent 52b3ed2920
commit b0e4e5cf31
3 changed files with 44 additions and 19 deletions

transducer_lstm/decode.py

@@ -24,15 +24,24 @@ Usage:
     --exp-dir ./transducer_lstm/exp \
     --max-duration 100 \
     --decoding-method greedy_search

 (2) beam search
 ./transducer_lstm/decode.py \
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_lstm/exp \
     --max-duration 100 \
     --decoding-method beam_search \
-    --beam-size 8
+    --beam-size 4
+
+(3) modified beam search
+./transducer_lstm/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer_lstm/exp \
+    --max-duration 100 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
 """
@@ -71,14 +80,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -112,8 +121,9 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
-        help="Used only when --decoding-method is beam_search",
+        default=4,
+        help="""Used only when --decoding-method is
+        beam_search or modified_beam_search""",
     )

     parser.add_argument(
@@ -123,7 +133,6 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
-
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -348,12 +357,19 @@ def main():
     params = get_params()
     params.update(vars(args))

-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )

     params.res_dir = params.exp_dir / params.decoding_method
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -423,8 +439,5 @@ def main():
     logging.info("Done!")


-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()

transducer_lstm/joiner.py

@@ -26,7 +26,7 @@ class Joiner(nn.Module):
         self.output_linear = nn.Linear(input_dim, output_dim)

     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
     ) -> torch.Tensor:
         """
         Args:
@@ -51,5 +51,7 @@ class Joiner(nn.Module):
         logit = F.relu(logit)

         output = self.output_linear(logit)
+        if not self.training:
+            output = output.squeeze(2).squeeze(1)

         return output
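
Two things change in the joiner: forward() now tolerates extra positional arguments (*unused), so call sites written against a richer joiner signature still work, and at inference time the singleton axes produced by search-time calls are squeezed away. A runnable sketch (a simplified stand-in; the add-then-ReLU combine step is assumed, since this hunk does not show it):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Joiner(nn.Module):
        def __init__(self, input_dim: int, output_dim: int):
            super().__init__()
            self.output_linear = nn.Linear(input_dim, output_dim)

        def forward(
            self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
        ) -> torch.Tensor:
            logit = encoder_out + decoder_out  # assumed combine step
            logit = F.relu(logit)
            output = self.output_linear(logit)
            if not self.training:
                # Collapse the singleton U and T axes used during search.
                output = output.squeeze(2).squeeze(1)
            return output

    joiner = Joiner(input_dim=8, output_dim=5).eval()
    enc = torch.randn(2, 1, 1, 8)  # (N, T=1, 1, C)
    dec = torch.randn(2, 1, 1, 8)  # (N, 1, U=1, C)
    print(joiner(enc, dec).shape)  # torch.Size([2, 5])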

transducer_lstm/train.py

@@ -634,13 +634,23 @@ def run(rank, world_size, args):
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+    try:
+        num_left = len(train_cuts)
+        num_removed = num_in_total - num_left
+        removed_percent = num_removed / num_in_total * 100
+
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
+        logging.info(f"After removing short and long utterances: {num_left}")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
+    except TypeError as e:
+        # You can ignore this error as previous versions of Lhotse work fine
+        # for the above code. In recent versions of Lhotse, it uses
+        # lazy filter, producing cutsets that don't have the __len__ method
+        logging.info(str(e))

     train_dl = librispeech.train_dataloaders(train_cuts)
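
The new try/except reflects a Lhotse behavior change: a lazily filtered cut set defines __iter__ but not __len__, so len() raises TypeError. A minimal plain-Python sketch of that failure mode (LazyCutSet is a made-up class, not Lhotse):

    class LazyCutSet:
        """Stands in for a lazily filtered cut set: iterable, but no __len__."""

        def __init__(self, items, predicate):
            self._items = items
            self._predicate = predicate

        def __iter__(self):
            return (x for x in self._items if self._predicate(x))

    cuts = LazyCutSet(range(10), lambda x: x % 2 == 0)
    print(sum(1 for _ in cuts))  # iterating works: prints 5
    try:
        len(cuts)
    except TypeError as e:
        print(e)  # object of type 'LazyCutSet' has no len()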