update decoding files

This commit is contained in:
marcoyang 2023-02-13 16:21:01 +08:00
parent b39ac0207e
commit 6018f222df
2 changed files with 57 additions and 19 deletions

View File

@ -704,10 +704,12 @@ def main():
tal_csasr = TAL_CSASRAsrDataModule(args) tal_csasr = TAL_CSASRAsrDataModule(args)
dev_cuts = tal_csasr.valid_cuts() dev_cuts = tal_csasr.valid_cuts()
dev_cuts = dev_cuts.subset(first=300)
dev_cuts = dev_cuts.map(text_normalize_for_cut) dev_cuts = dev_cuts.map(text_normalize_for_cut)
dev_dl = tal_csasr.valid_dataloaders(dev_cuts) dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
test_cuts = tal_csasr.test_cuts() test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.subset(first=300)
test_cuts = test_cuts.map(text_normalize_for_cut) test_cuts = test_cuts.map(text_normalize_for_cut)
test_dl = tal_csasr.test_dataloaders(test_cuts) test_dl = tal_csasr.test_dataloaders(test_cuts)

View File

@ -55,6 +55,7 @@ Usage:
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
""" """
import re
import argparse import argparse
import logging import logging
import warnings import warnings
@ -66,11 +67,13 @@ import numpy as np
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import TAL_CSASRAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet from lhotse import CutSet
from lhotse.cut import Cut
from lstm import LOG_EPSILON, stack_states, unstack_states from lstm import LOG_EPSILON, stack_states, unstack_states
from local.text_normalize import text_normalize
from stream import Stream from stream import Stream
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -82,6 +85,8 @@ from icefall.checkpoint import (
load_checkpoint, load_checkpoint,
) )
from icefall.decode import one_best_decoding from icefall.decode import one_best_decoding
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
get_texts, get_texts,
@ -143,10 +148,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_char",
help="Path to the BPE model", help="Path to the dir containing bpe.model and tokens.txt",
) )
parser.add_argument( parser.add_argument(
@ -617,12 +622,25 @@ def create_streaming_feature_extractor() -> Fbank:
opts.mel_opts.num_bins = 80 opts.mel_opts.num_bins = 80
return Fbank(opts) return Fbank(opts)
def filter_zh_en(text: str):
import re
pattern = re.compile(r"([\u4e00-\u9fff])")
chars = pattern.split(text.upper())
chars_new = []
for char in chars:
if char != "":
tokens = char.strip().split(" ")
chars_new.extend(tokens)
return chars_new
def decode_dataset( def decode_dataset(
cuts: CutSet, cuts: CutSet,
model: nn.Module, model: nn.Module,
params: AttributeDict, params: AttributeDict,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
): ):
"""Decode dataset. """Decode dataset.
@ -691,11 +709,12 @@ def decode_dataset(
) )
for i in sorted(finished_streams, reverse=True): for i in sorted(finished_streams, reverse=True):
hyp = streams[i].decoding_result()
decode_results.append( decode_results.append(
( (
streams[i].id, streams[i].id,
streams[i].ground_truth.split(), filter_zh_en(streams[i].ground_truth),
sp.decode(streams[i].decoding_result()).split(), sp.decode([lexicon.token_table[idx] for idx in hyp]),
) )
) )
del streams[i] del streams[i]
@ -712,11 +731,12 @@ def decode_dataset(
) )
for i in sorted(finished_streams, reverse=True): for i in sorted(finished_streams, reverse=True):
hyp = streams[i].decoding_result()
decode_results.append( decode_results.append(
( (
streams[i].id, streams[i].id,
streams[i].ground_truth.split(), filter_zh_en(streams[i].ground_truth),
sp.decode(streams[i].decoding_result()).split(), [sp.decode(lexicon.token_table[idx]) for idx in hyp],
) )
) )
del streams[i] del streams[i]
@ -781,7 +801,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) TAL_CSASRAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -822,13 +842,17 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py lexicon = Lexicon(params.lang_dir)
params.blank_id = sp.piece_to_id("<blk>") graph_compiler = CharCtcTrainingGraphCompiler(
params.unk_id = sp.piece_to_id("<unk>") lexicon=lexicon,
params.vocab_size = sp.get_piece_size() device=device,
)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
params.device = device params.device = device
@ -924,13 +948,23 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args) def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text)
return c
tal_csasr = TAL_CSASRAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts() dev_cuts = tal_csasr.valid_cuts()
test_other_cuts = librispeech.test_other_cuts() dev_cuts = dev_cuts.map(text_normalize_for_cut)
test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
test_sets = ["test-clean", "test-other"] test_sets = ["dev", "test"]
test_cuts = [test_clean_cuts, test_other_cuts] test_cuts = [dev_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts): for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset( results_dict = decode_dataset(
@ -938,6 +972,8 @@ def main():
model=model, model=model,
params=params, params=params,
sp=sp, sp=sp,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )