Fix decode.py

This commit is contained in:
Fangjun Kuang 2022-02-17 18:31:50 +08:00
parent 1930d72b17
commit 981bf74364

View File

@ -18,18 +18,18 @@
"""
Usage:
(1) greedy search
./transducer_stateless/decode.py \
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless/decode.py \
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
@ -45,11 +45,12 @@ from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import AsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from librispeech import LibriSpeech
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -86,7 +87,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp",
default="transducer_stateless_multi_datasets/exp",
help="The experiment dir",
)
@ -190,11 +191,17 @@ def get_transducer_model(params: AttributeDict):
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
decoder_giga=decoder_giga,
joiner_giga=joiner_giga,
)
return model
@ -389,7 +396,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -441,7 +448,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval()
@ -450,13 +459,14 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
asr_datamodule = AsrDataModule(args)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]