Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-10 02:22:17 +00:00
update decoding files
parent b39ac0207e
commit 6018f222df
@@ -704,10 +704,12 @@ def main():
     tal_csasr = TAL_CSASRAsrDataModule(args)
 
     dev_cuts = tal_csasr.valid_cuts()
+    dev_cuts = dev_cuts.subset(first=300)
     dev_cuts = dev_cuts.map(text_normalize_for_cut)
     dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
 
     test_cuts = tal_csasr.test_cuts()
+    test_cuts = test_cuts.subset(first=300)
     test_cuts = test_cuts.map(text_normalize_for_cut)
     test_dl = tal_csasr.test_dataloaders(test_cuts)
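
A brief illustration, not part of the commit, of what the added subset/map calls do in lhotse (the manifest path is a made-up placeholder):

    from lhotse import CutSet

    # Placeholder manifest path, not taken from the recipe.
    cuts = CutSet.from_file("data/fbank/tal_csasr_cuts_dev.jsonl.gz")

    # Keep only the first 300 cuts so a decoding pass finishes quickly.
    cuts = cuts.subset(first=300)

    # map() is lazy: the callback runs as the dataloader iterates the cuts.
    cuts = cuts.map(lambda c: c)  # stand-in for text_normalize_for_cut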
@@ -55,6 +55,7 @@ Usage:
     --max-contexts 4 \
     --max-states 8
 """
+import re
 import argparse
 import logging
 import warnings
@@ -66,11 +67,13 @@ import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import TAL_CSASRAsrDataModule
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
+from lhotse.cut import Cut
 from lstm import LOG_EPSILON, stack_states, unstack_states
+from local.text_normalize import text_normalize
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
@@ -82,6 +85,8 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.decode import one_best_decoding
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -143,10 +148,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="Path to the dir containing bpe.model and tokens.txt",
     )
 
     parser.add_argument(
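
A hypothetical invocation after this change (the script name and paths are assumptions; only --lang-dir, --exp-dir, --max-contexts and --max-states appear in the code shown in this diff):

    python ./streaming_decode.py \
        --exp-dir ./exp \
        --lang-dir data/lang_char \
        --max-contexts 4 \
        --max-states 8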
@@ -617,12 +622,25 @@ def create_streaming_feature_extractor() -> Fbank:
     opts.mel_opts.num_bins = 80
     return Fbank(opts)
 
+def filter_zh_en(text: str):
+    import re
+    pattern = re.compile(r"([\u4e00-\u9fff])")
+
+    chars = pattern.split(text.upper())
+    chars_new = []
+    for char in chars:
+        if char != "":
+            tokens = char.strip().split(" ")
+            chars_new.extend(tokens)
+    return chars_new
 
 def decode_dataset(
     cuts: CutSet,
     model: nn.Module,
     params: AttributeDict,
     sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ):
     """Decode dataset.
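
A worked example, not from the commit, of what the new filter_zh_en helper returns: because the regex has a capturing group, pattern.split keeps every CJK character as its own element, while the Latin chunks are upper-cased and split on spaces, so scoring ends up per Chinese character and per English word:

    filter_zh_en("我们 today 去 school")
    # -> ['我', '们', 'TODAY', '去', 'SCHOOL']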
@@ -691,11 +709,12 @@ def decode_dataset(
             )
 
             for i in sorted(finished_streams, reverse=True):
+                hyp = streams[i].decoding_result()
                 decode_results.append(
                     (
                         streams[i].id,
-                        streams[i].ground_truth.split(),
-                        sp.decode(streams[i].decoding_result()).split(),
+                        filter_zh_en(streams[i].ground_truth),
+                        sp.decode([lexicon.token_table[idx] for idx in hyp]),
                     )
                 )
                 del streams[i]
@@ -712,11 +731,12 @@ def decode_dataset(
         )
 
         for i in sorted(finished_streams, reverse=True):
+            hyp = streams[i].decoding_result()
             decode_results.append(
                 (
                     streams[i].id,
-                    streams[i].ground_truth.split(),
-                    sp.decode(streams[i].decoding_result()).split(),
+                    filter_zh_en(streams[i].ground_truth),
+                    [sp.decode(lexicon.token_table[idx]) for idx in hyp],
                 )
             )
             del streams[i]
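
Side note, not part of the diff: the two loops above convert the hypothesis differently. The first joins all lexicon pieces into a single string with sp.decode(pieces); the second decodes piece by piece and keeps a list. A rough sketch of the first conversion, with made-up token IDs and pieces:

    hyp = [23, 57, 102]                                 # made-up token IDs
    pieces = [lexicon.token_table[idx] for idx in hyp]  # e.g. ["你", "好", "▁HELLO"]
    text = sp.decode(pieces)                            # one joined string, e.g. "你好 HELLO"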
@@ -781,7 +801,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    TAL_CSASRAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -822,13 +842,17 @@ def main():
 
     logging.info(f"Device: {device}")
 
+    bpe_model = params.lang_dir + "/bpe.model"
     sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    sp.load(bpe_model)
 
-    # <blk> and <unk> are defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     params.device = device
 
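
For context, a rough sketch (assuming a typical icefall data/lang_char layout, which is not shown in this diff) of what the lexicon-based setup reads:

    from icefall.lexicon import Lexicon

    lexicon = Lexicon("data/lang_char")        # parses tokens.txt / words.txt in the lang dir
    blank_id = lexicon.token_table["<blk>"]    # str -> int lookup; usually 0 in these setups
    vocab_size = max(lexicon.tokens) + 1       # tokens = token IDs, excluding disambig symbols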
@@ -924,13 +948,23 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
+
+    tal_csasr = TAL_CSASRAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    dev_cuts = tal_csasr.valid_cuts()
+    dev_cuts = dev_cuts.map(text_normalize_for_cut)
 
-    test_sets = ["test-clean", "test-other"]
-    test_cuts = [test_clean_cuts, test_other_cuts]
+    test_cuts = tal_csasr.test_cuts()
+    test_cuts = test_cuts.map(text_normalize_for_cut)
+
+    test_sets = ["dev", "test"]
+    test_cuts = [dev_cuts, test_cuts]
 
     for test_set, test_cut in zip(test_sets, test_cuts):
         results_dict = decode_dataset(
@@ -938,6 +972,8 @@ def main():
             model=model,
             params=params,
             sp=sp,
+            lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
         )
 