diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 03d1b840b..8db103672 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -24,15 +24,24 @@ Usage: --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method greedy_search -(2) beam search +(2) beam search ./transducer_lstm/decode.py \ --epoch 14 \ --avg 7 \ --exp-dir ./transducer_lstm/exp \ --max-duration 100 \ --decoding-method beam_search \ - --beam-size 8 + --beam-size 4 + +(3) modified beam search +./transducer_lstm/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_lstm/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 """ @@ -71,14 +80,14 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=77, + default=29, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=55, + default=13, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -112,8 +121,9 @@ def get_parser(): parser.add_argument( "--beam-size", type=int, - default=5, - help="Used only when --decoding-method is beam_search", + default=4, + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", ) parser.add_argument( @@ -123,7 +133,6 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) - parser.add_argument( "--max-sym-per-frame", type=int, @@ -348,12 +357,19 @@ def main(): params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": + if "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -423,8 +439,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/transducer_lstm/joiner.py b/egs/librispeech/ASR/transducer_lstm/joiner.py index 0422f8a6f..8c3710011 100644 --- a/egs/librispeech/ASR/transducer_lstm/joiner.py +++ b/egs/librispeech/ASR/transducer_lstm/joiner.py @@ -26,7 +26,7 @@ class Joiner(nn.Module): self.output_linear = nn.Linear(input_dim, output_dim) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused ) -> torch.Tensor: """ Args: @@ -51,5 +51,7 @@ class Joiner(nn.Module): logit = F.relu(logit) output = self.output_linear(logit) + if not self.training: + output = output.squeeze(2).squeeze(1) return output diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 7f4dc32cf..98859a58f 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -634,13 +634,23 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 + try: + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + logging.info( + f"Before removing short and long utterances: {num_in_total}" + ) + logging.info(f"After removing short and long utterances: {num_left}") + logging.info( + f"Removed {num_removed} utterances ({removed_percent:.5f}%)" + ) + except TypeError as e: + # You can ignore this error as previous versions of Lhotse work fine + # for the above code. In recent versions of Lhotse, it uses + # lazy filter, producing cutsets that don't have the __len__ method + logging.info(str(e)) train_dl = librispeech.train_dataloaders(train_cuts)