Decoding scripts

Signed-off-by: wd929 <dingwen929@gmail.com>
This commit is contained in:
wd929 2023-10-24 17:08:14 +08:00
parent 5aafaa35bd
commit eb11c2486f

View File

@@ -2,6 +2,7 @@
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# 2023 NVIDIA Corporation (Wen Ding)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -20,9 +21,9 @@
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--epoch 13 \
--avg 10 \
--exp-dir ./zipformer/exp_maxdur500_lr0.0225/ \
--max-duration 600 \
--decoding-method greedy_search
@@ -106,7 +107,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import ICMCAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -729,7 +730,7 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True
)
test_set_wers[key] = wer
@@ -755,7 +756,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
ICMCAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -1014,16 +1015,16 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
icmc = ICMCAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_ihm_cuts = icmc.dev_ihm_cuts()
test_shm_cuts = icmc.dev_shm_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_ihm_dl = icmc.test_dataloaders(test_ihm_cuts)
test_shm_dl = icmc.test_dataloaders(test_shm_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
test_sets = ["test-ihm", "test-shm"]
test_dl = [test_ihm_dl, test_shm_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(