update decoding files

This commit is contained in:
marcoyang 2023-02-13 16:21:01 +08:00
parent b39ac0207e
commit 6018f222df
2 changed files with 57 additions and 19 deletions

View File

@ -704,10 +704,12 @@ def main():
tal_csasr = TAL_CSASRAsrDataModule(args)
dev_cuts = tal_csasr.valid_cuts()
dev_cuts = dev_cuts.subset(first=300)
dev_cuts = dev_cuts.map(text_normalize_for_cut)
dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.subset(first=300)
test_cuts = test_cuts.map(text_normalize_for_cut)
test_dl = tal_csasr.test_dataloaders(test_cuts)

View File

@ -55,6 +55,7 @@ Usage:
--max-contexts 4 \
--max-states 8
"""
import re
import argparse
import logging
import warnings
@ -66,11 +67,13 @@ import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import TAL_CSASRAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lhotse.cut import Cut
from lstm import LOG_EPSILON, stack_states, unstack_states
from local.text_normalize import text_normalize
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@ -82,6 +85,8 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
@ -143,10 +148,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="Path to the dir containing bpe.model and tokens.txt",
)
parser.add_argument(
@ -617,12 +622,25 @@ def create_streaming_feature_extractor() -> Fbank:
opts.mel_opts.num_bins = 80
return Fbank(opts)
def filter_zh_en(text: str):
import re
pattern = re.compile(r"([\u4e00-\u9fff])")
chars = pattern.split(text.upper())
chars_new = []
for char in chars:
if char != "":
tokens = char.strip().split(" ")
chars_new.extend(tokens)
return chars_new
def decode_dataset(
cuts: CutSet,
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
):
"""Decode dataset.
@ -691,11 +709,12 @@ def decode_dataset(
)
for i in sorted(finished_streams, reverse=True):
hyp = streams[i].decoding_result()
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
filter_zh_en(streams[i].ground_truth),
sp.decode([lexicon.token_table[idx] for idx in hyp]),
)
)
del streams[i]
@ -712,11 +731,12 @@ def decode_dataset(
)
for i in sorted(finished_streams, reverse=True):
hyp = streams[i].decoding_result()
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
filter_zh_en(streams[i].ground_truth),
[sp.decode(lexicon.token_table[idx]) for idx in hyp],
)
)
del streams[i]
@ -781,7 +801,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
TAL_CSASRAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -822,13 +842,17 @@ def main():
logging.info(f"Device: {device}")
bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
sp.load(bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
params.device = device
@ -924,13 +948,23 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text)
return c
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
tal_csasr = TAL_CSASRAsrDataModule(args)
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
dev_cuts = tal_csasr.valid_cuts()
dev_cuts = dev_cuts.map(text_normalize_for_cut)
test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
test_sets = ["dev", "test"]
test_cuts = [dev_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
@ -938,6 +972,8 @@ def main():
model=model,
params=params,
sp=sp,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)