mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
update decoding files
This commit is contained in:
parent
b39ac0207e
commit
6018f222df
@ -704,10 +704,12 @@ def main():
|
||||
tal_csasr = TAL_CSASRAsrDataModule(args)
|
||||
|
||||
dev_cuts = tal_csasr.valid_cuts()
|
||||
dev_cuts = dev_cuts.subset(first=300)
|
||||
dev_cuts = dev_cuts.map(text_normalize_for_cut)
|
||||
dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
|
||||
|
||||
test_cuts = tal_csasr.test_cuts()
|
||||
test_cuts = test_cuts.subset(first=300)
|
||||
test_cuts = test_cuts.map(text_normalize_for_cut)
|
||||
test_dl = tal_csasr.test_dataloaders(test_cuts)
|
||||
|
||||
|
@ -55,6 +55,7 @@ Usage:
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
import re
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
@ -66,11 +67,13 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from asr_datamodule import TAL_CSASRAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from lhotse.cut import Cut
|
||||
from lstm import LOG_EPSILON, stack_states, unstack_states
|
||||
from local.text_normalize import text_normalize
|
||||
from stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
@ -82,6 +85,8 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
@ -143,10 +148,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
default="data/lang_char",
|
||||
help="Path to the dir containing bpe.model and tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -617,12 +622,25 @@ def create_streaming_feature_extractor() -> Fbank:
|
||||
opts.mel_opts.num_bins = 80
|
||||
return Fbank(opts)
|
||||
|
||||
def filter_zh_en(text: str):
|
||||
import re
|
||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||
|
||||
chars = pattern.split(text.upper())
|
||||
chars_new = []
|
||||
for char in chars:
|
||||
if char != "":
|
||||
tokens = char.strip().split(" ")
|
||||
chars_new.extend(tokens)
|
||||
return chars_new
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
model: nn.Module,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
):
|
||||
"""Decode dataset.
|
||||
@ -691,11 +709,12 @@ def decode_dataset(
|
||||
)
|
||||
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = streams[i].decoding_result()
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
filter_zh_en(streams[i].ground_truth),
|
||||
sp.decode([lexicon.token_table[idx] for idx in hyp]),
|
||||
)
|
||||
)
|
||||
del streams[i]
|
||||
@ -712,11 +731,12 @@ def decode_dataset(
|
||||
)
|
||||
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = streams[i].decoding_result()
|
||||
decode_results.append(
|
||||
(
|
||||
streams[i].id,
|
||||
streams[i].ground_truth.split(),
|
||||
sp.decode(streams[i].decoding_result()).split(),
|
||||
filter_zh_en(streams[i].ground_truth),
|
||||
[sp.decode(lexicon.token_table[idx]) for idx in hyp],
|
||||
)
|
||||
)
|
||||
del streams[i]
|
||||
@ -781,7 +801,7 @@ def save_results(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
TAL_CSASRAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
@ -822,13 +842,17 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
bpe_model = params.lang_dir + "/bpe.model"
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
sp.load(bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
params.device = device
|
||||
|
||||
@ -924,13 +948,23 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
def text_normalize_for_cut(c: Cut):
|
||||
# Text normalize for each sample
|
||||
text = c.supervisions[0].text
|
||||
text = text.strip("\n").strip("\t")
|
||||
c.supervisions[0].text = text_normalize(text)
|
||||
return c
|
||||
|
||||
tal_csasr = TAL_CSASRAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
dev_cuts = tal_csasr.valid_cuts()
|
||||
dev_cuts = dev_cuts.map(text_normalize_for_cut)
|
||||
|
||||
test_cuts = tal_csasr.test_cuts()
|
||||
test_cuts = test_cuts.map(text_normalize_for_cut)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_cuts = [test_clean_cuts, test_other_cuts]
|
||||
test_sets = ["dev", "test"]
|
||||
test_cuts = [dev_cuts, test_cuts]
|
||||
|
||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||
results_dict = decode_dataset(
|
||||
@ -938,6 +972,8 @@ def main():
|
||||
model=model,
|
||||
params=params,
|
||||
sp=sp,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user