Repository: https://github.com/k2-fsa/icefall.git (mirror, synced 2025-12-11 06:55:27 +00:00)

commit 981bf74364
parent 1930d72b17

    Fix decode.py
@@ -18,18 +18,18 @@
 """
 Usage:
 (1) greedy search
-./transducer_stateless/decode.py \
+./transducer_stateless_multi_datasets/decode.py \
         --epoch 14 \
         --avg 7 \
-        --exp-dir ./transducer_stateless/exp \
+        --exp-dir ./transducer_stateless_multi_datasets/exp \
         --max-duration 100 \
         --decoding-method greedy_search
 
 (2) beam search
-./transducer_stateless/decode.py \
+./transducer_stateless_multi_datasets/decode.py \
         --epoch 14 \
         --avg 7 \
-        --exp-dir ./transducer_stateless/exp \
+        --exp-dir ./transducer_stateless_multi_datasets/exp \
         --max-duration 100 \
         --decoding-method beam_search \
         --beam-size 4
@@ -45,11 +45,12 @@ from typing import Dict, List, Tuple
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AsrDataModule
 from beam_search import beam_search, greedy_search, modified_beam_search
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
+from librispeech import LibriSpeech
 from model import Transducer
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -86,7 +87,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp",
+        default="transducer_stateless_multi_datasets/exp",
         help="The experiment dir",
     )
 
@@ -190,11 +191,17 @@ def get_transducer_model(params: AttributeDict):
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
+    decoder_giga = get_decoder_model(params)
+    joiner_giga = get_joiner_model(params)
+
     model = Transducer(
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        decoder_giga=decoder_giga,
+        joiner_giga=joiner_giga,
     )
+
     return model
 
 
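The hunk above is the core of the fix: at decode time the model is now built with the same second-corpus decoder and joiner branches that the multi-dataset training recipe uses, presumably so that its parameter names can line up with the training checkpoints. A minimal sketch of the kind of wrapper the constructor call implies follows; only the constructor signature is taken from the call in get_transducer_model() above, everything else is an illustrative assumption and not the recipe's actual model.py.

import torch.nn as nn


class Transducer(nn.Module):
    """Sketch only: one shared encoder plus two decoder/joiner pairs,
    one per corpus. Signature mirrors the call in get_transducer_model();
    the class body is illustrative, not copied from the recipe."""

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        joiner: nn.Module,
        decoder_giga: nn.Module,
        joiner_giga: nn.Module,
    ) -> None:
        super().__init__()
        self.encoder = encoder            # shared acoustic encoder
        self.decoder = decoder            # LibriSpeech prediction network
        self.joiner = joiner              # LibriSpeech joiner
        self.decoder_giga = decoder_giga  # second-corpus branch (GigaSpeech, per the _giga suffix)
        self.joiner_giga = joiner_giga    # second-corpus joiner

Only the LibriSpeech branch should be exercised on these test sets, but the extra submodules still have to be constructed so that checkpoint loading (see the strict=False change further down) has matching slots for, or can safely skip, the corresponding weights.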
@@ -389,7 +396,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    AsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -441,7 +448,9 @@ def main():
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
         model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
 
     model.to(device)
     model.eval()
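Passing strict=False when loading the averaged checkpoint is the other half of the fix: it lets the load succeed even if the checkpoint's parameter set does not match the freshly built model key-for-key (for example around the *_giga branches), instead of raising. Below is a small self-contained illustration of the PyTorch behaviour being relied on, using toy modules rather than the recipe's classes.

import torch.nn as nn

# Toy stand-ins: the "checkpoint" model lacks a submodule that the decoding
# model has, mimicking a key mismatch between training and decoding graphs.
ckpt_model = nn.ModuleDict({"decoder": nn.Linear(4, 4)})
decode_model = nn.ModuleDict(
    {"decoder": nn.Linear(4, 4), "decoder_giga": nn.Linear(4, 4)}
)

# With the default strict=True, any missing/unexpected key raises RuntimeError.
try:
    decode_model.load_state_dict(ckpt_model.state_dict())
except RuntimeError as err:
    print("strict load failed:", err)

# With strict=False, the overlapping keys are loaded and the mismatches are
# returned instead of raising -- this is what decode.py now relies on.
result = decode_model.load_state_dict(ckpt_model.state_dict(), strict=False)
print("missing keys:", result.missing_keys)        # decoder_giga.weight, ...
print("unexpected keys:", result.unexpected_keys)  # none in this toy case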
@@ -450,13 +459,14 @@
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    asr_datamodule = AsrDataModule(args)
+    librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
     test_clean_cuts = librispeech.test_clean_cuts()
     test_other_cuts = librispeech.test_other_cuts()
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
+    test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
 
     test_sets = ["test-clean", "test-other"]
     test_dl = [test_clean_dl, test_other_dl]
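The last hunk completes the datamodule split: AsrDataModule(args) only knows how to turn lhotse cuts into dataloaders, while LibriSpeech(manifest_dir=...) only knows where the corpus's CutSets live. A short sketch of the contract this implies; the CutProvider protocol and build_test_dataloaders helper are names invented here for illustration, not part of the recipe.

from typing import Dict, Protocol


class CutProvider(Protocol):
    """Anything that can hand back the two LibriSpeech test CutSets."""

    def test_clean_cuts(self): ...
    def test_other_cuts(self): ...


def build_test_dataloaders(asr_datamodule, provider: CutProvider) -> Dict[str, object]:
    # Mirrors the calls in main() above: the corpus object supplies cuts,
    # the datamodule turns each CutSet into a test dataloader.
    return {
        "test-clean": asr_datamodule.test_dataloaders(provider.test_clean_cuts()),
        "test-other": asr_datamodule.test_dataloaders(provider.test_other_cuts()),
    }

With this shape, another corpus helper exposing the same two methods could be decoded through the exact same datamodule, which appears to be the point of moving the LibriSpeech-specific pieces out of the shared AsrDataModule.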