Fix decode.py for the transducer_stateless_multi_datasets recipe

This commit is contained in:
Fangjun Kuang 2022-02-17 18:31:50 +08:00
parent 1930d72b17
commit 981bf74364

View File

@@ -18,18 +18,18 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./transducer_stateless/decode.py \ ./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search
./transducer_stateless/decode.py \ ./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
@@ -45,11 +45,12 @@ from typing import Dict, List, Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import AsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from librispeech import LibriSpeech
from model import Transducer from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -86,7 +87,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/exp", default="transducer_stateless_multi_datasets/exp",
help="The experiment dir", help="The experiment dir",
) )
@@ -190,11 +191,17 @@ def get_transducer_model(params: AttributeDict):
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
model = Transducer( model = Transducer(
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
decoder_giga=decoder_giga,
joiner_giga=joiner_giga,
) )
return model return model
@@ -389,7 +396,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) AsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@@ -441,7 +448,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt") filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device) model.to(device)
model.eval() model.eval()
@@ -450,13 +459,14 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args) asr_datamodule = AsrDataModule(args)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts() test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts() test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts) test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"] test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl] test_dl = [test_clean_dl, test_other_dl]