From 6018f222df9d8dda03446e7b497d776d42b8ce6e Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Mon, 13 Feb 2023 16:21:01 +0800
Subject: [PATCH] update decoding files

---
 .../ASR/lstm_transducer_stateless3/decode.py  |  2 +
 .../streaming_decode.py                       | 74 ++++++++++++++-----
 2 files changed, 57 insertions(+), 19 deletions(-)

diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py
index bdbe69b06..f40d22cd8 100755
--- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py
+++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/decode.py
@@ -704,10 +704,12 @@ def main():
     tal_csasr = TAL_CSASRAsrDataModule(args)
 
     dev_cuts = tal_csasr.valid_cuts()
+    dev_cuts = dev_cuts.subset(first=300)
     dev_cuts = dev_cuts.map(text_normalize_for_cut)
     dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
 
     test_cuts = tal_csasr.test_cuts()
+    test_cuts = test_cuts.subset(first=300)
     test_cuts = test_cuts.map(text_normalize_for_cut)
     test_dl = tal_csasr.test_dataloaders(test_cuts)
 
diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py
index 109746ed5..2bc5bfe8c 100644
--- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -55,6 +55,7 @@ Usage:
         --max-contexts 4 \
         --max-states 8
 """
+import re
 import argparse
 import logging
 import warnings
@@ -66,11 +67,13 @@ import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import TAL_CSASRAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
+from lhotse.cut import Cut
 from lstm import LOG_EPSILON, stack_states, unstack_states
+from local.text_normalize import text_normalize
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -82,6 +85,8 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.decode import one_best_decoding
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -143,10 +148,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="Path to the dir containing bpe.model and tokens.txt",
     )
 
     parser.add_argument(
@@ -617,12 +622,25 @@ def create_streaming_feature_extractor() -> Fbank:
     opts.mel_opts.num_bins = 80
     return Fbank(opts)
 
+def filter_zh_en(text: str):
+    import re
+    pattern = re.compile(r"([\u4e00-\u9fff])")
+
+    chars = pattern.split(text.upper())
+    chars_new = []
+    for char in chars:
+        if char != "":
+            tokens = char.strip().split(" ")
+            chars_new.extend(tokens)
+    return chars_new
 
 def decode_dataset(
     cuts: CutSet,
     model: nn.Module,
     params: AttributeDict,
     sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ):
     """Decode dataset.
@@ -691,11 +709,12 @@ def decode_dataset(
             )
 
             for i in sorted(finished_streams, reverse=True):
+                hyp = streams[i].decoding_result()
                 decode_results.append(
                     (
                         streams[i].id,
-                        streams[i].ground_truth.split(),
-                        sp.decode(streams[i].decoding_result()).split(),
+                        filter_zh_en(streams[i].ground_truth),
+                        sp.decode([lexicon.token_table[idx] for idx in hyp]),
                     )
                 )
                 del streams[i]
@@ -712,11 +731,12 @@ def decode_dataset(
         )
 
         for i in sorted(finished_streams, reverse=True):
+            hyp = streams[i].decoding_result()
             decode_results.append(
                 (
                     streams[i].id,
-                    streams[i].ground_truth.split(),
-                    sp.decode(streams[i].decoding_result()).split(),
+                    filter_zh_en(streams[i].ground_truth),
+                    [sp.decode(lexicon.token_table[idx]) for idx in hyp],
                 )
             )
             del streams[i]
@@ -781,7 +801,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    TAL_CSASRAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -822,13 +842,17 @@
     logging.info(f"Device: {device}")
 
+    bpe_model = params.lang_dir + "/bpe.model"
     sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    sp.load(bpe_model)
 
-    # <blk> and <unk> are defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     params.device = device
 
@@ -924,13 +948,23 @@
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
+
+    tal_csasr = TAL_CSASRAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    dev_cuts = tal_csasr.valid_cuts()
+    dev_cuts = dev_cuts.map(text_normalize_for_cut)
+
+    test_cuts = tal_csasr.test_cuts()
+    test_cuts = test_cuts.map(text_normalize_for_cut)
 
-    test_sets = ["test-clean", "test-other"]
-    test_cuts = [test_clean_cuts, test_other_cuts]
+    test_sets = ["dev", "test"]
+    test_cuts = [dev_cuts, test_cuts]
 
     for test_set, test_cut in zip(test_sets, test_cuts):
@@ -938,6 +972,8 @@
         results_dict = decode_dataset(
             model=model,
             params=params,
             sp=sp,
+            lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
         )
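
Note (not part of the patch): below is a minimal, self-contained sketch of what the new filter_zh_en helper returns for a mixed Chinese/English reference string. The function body is copied from the hunk above; the sample sentence is invented purely for illustration. Chinese is split into single characters and English into upper-cased words, which is the granularity used on the reference side of decode_results, so scoring compares Chinese per character and English per word.

import re

def filter_zh_en(text: str):
    # Keep each CJK character as its own token (the capturing group makes
    # re.split retain the delimiters), then break the remaining non-Chinese
    # chunks into space-separated words.
    pattern = re.compile(r"([\u4e00-\u9fff])")
    chars = pattern.split(text.upper())
    chars_new = []
    for char in chars:
        if char != "":
            tokens = char.strip().split(" ")
            chars_new.extend(tokens)
    return chars_new

# Invented example sentence, purely for illustration:
print(filter_zh_en("我们的 test case"))
# -> ['我', '们', '的', 'TEST', 'CASE']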