diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index c101d9397..4082c3e97 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -18,18 +18,18 @@ """ Usage: (1) greedy search -./transducer_stateless/decode.py \ +./transducer_stateless_multi_datasets/decode.py \ --epoch 14 \ --avg 7 \ - --exp-dir ./transducer_stateless/exp \ + --exp-dir ./transducer_stateless_multi_datasets/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./transducer_stateless/decode.py \ +./transducer_stateless_multi_datasets/decode.py \ --epoch 14 \ --avg 7 \ - --exp-dir ./transducer_stateless/exp \ + --exp-dir ./transducer_stateless_multi_datasets/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 @@ -45,11 +45,12 @@ from typing import Dict, List, Tuple import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner +from librispeech import LibriSpeech from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint @@ -86,7 +87,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp", + default="transducer_stateless_multi_datasets/exp", help="The experiment dir", ) @@ -190,11 +191,17 @@ def get_transducer_model(params: AttributeDict): decoder = get_decoder_model(params) joiner = get_joiner_model(params) + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, ) + return model @@ -389,7 +396,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -441,7 +448,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) model.to(device) model.eval() @@ -450,13 +459,14 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl]