Decoding scripts

Signed-off-by: wd929 <dingwen929@gmail.com>
This commit is contained in:
wd929 2023-10-24 17:08:14 +08:00
parent 5aafaa35bd
commit eb11c2486f

View File

@@ -2,6 +2,7 @@
# #
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao) # Zengwei Yao)
# 2023 NVIDIA Corporation (Wen Ding)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@@ -20,9 +21,9 @@
Usage: Usage:
(1) greedy search (1) greedy search
./zipformer/decode.py \ ./zipformer/decode.py \
--epoch 28 \ --epoch 13 \
--avg 15 \ --avg 10 \
--exp-dir ./zipformer/exp \ --exp-dir ./zipformer/exp_maxdur500_lr0.0225/ \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
@@ -106,7 +107,7 @@ import k2
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import ICMCAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
@@ -729,7 +730,7 @@ def save_results(
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True
) )
test_set_wers[key] = wer test_set_wers[key] = wer
@@ -755,7 +756,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) ICMCAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser) LmScorer.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@@ -1014,16 +1015,16 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args) icmc = ICMCAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts() test_ihm_cuts = icmc.dev_ihm_cuts()
test_other_cuts = librispeech.test_other_cuts() test_shm_cuts = icmc.dev_shm_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_ihm_dl = icmc.test_dataloaders(test_ihm_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts) test_shm_dl = icmc.test_dataloaders(test_shm_cuts)
test_sets = ["test-clean", "test-other"] test_sets = ["test-ihm", "test-shm"]
test_dl = [test_clean_dl, test_other_dl] test_dl = [test_ihm_dl, test_shm_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(