updates for multi dataset decoding

JinZr 2023-07-24 15:13:06 +08:00
parent 49b0a6d952
commit a91c90636b
4 changed files with 117 additions and 52 deletions

View File

@@ -105,7 +105,7 @@ class LibriSpeechAsrDataModule:
         group.add_argument(
             "--max-duration",
             type=int,
-            default=200.0,
+            default=300.0,
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
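Note: --max-duration is the per-batch budget of pooled audio seconds, typically consumed by lhotse's DynamicBucketingSampler in these recipes. A minimal sketch (an assumption, not part of this commit) of what raising the default from 200 to 300 changes; the manifest path is a placeholder:

    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler

    cuts = load_manifest_lazy("data/fbank/aishell_cuts_test.jsonl.gz")  # placeholder path
    sampler = DynamicBucketingSampler(cuts, max_duration=300.0, shuffle=False)
    for batch in sampler:
        print(sum(c.duration for c in batch))  # stays at or just under 300 s
        break

Larger batches speed up decoding but raise peak GPU memory, hence the OOM caveat in the help string.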

View File

@@ -117,6 +117,7 @@ from beam_search import (
     modified_beam_search,
 )
 from train import add_model_arguments, get_params, get_model
+from multi_dataset import MultiDataset
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -191,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default="data/lang_bpe_2000/bpe.model",
         help="Path to the BPE model",
     )
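Note: a quick sketch of how decode.py consumes this flag via the standard sentencepiece API; the token IDs below are invented for illustration:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_2000/bpe.model")  # the new multi-dataset BPE model
    print(sp.vocab_size())        # expected to be 2000 for this model
    print(sp.decode([3, 7, 11]))  # invented token IDs -> text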
@@ -273,8 +274,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
         "--max-sym-per-frame",
@@ -425,10 +425,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -534,6 +531,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
@@ -549,8 +547,8 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                hyp_text = "".join(hyp_words)
+                this_batch.append((cut_id, ref_text, hyp_text))
 
             results[name].extend(this_batch)
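Note: together with the texts preprocessing two hunks up, this switches scoring from word level to character level, which is the right granularity for the Chinese test sets, so the reported "WER" is effectively CER. A worked example with invented strings:

    ref_text = "今天 天气 很好"                        # invented reference
    ref_chars = list(str(ref_text).replace(" ", ""))  # ['今', '天', '天', '气', '很', '好']

    hyp_words = ["今", "天", "天", "气", "很", "好"]    # invented hypothesis tokens
    hyp_text = "".join(hyp_words)                     # '今天天气很好'

    # each results entry is now (cut_id, ref_char_list, hyp_string)
    print(("fake-cut-id", ref_chars, hyp_text))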
@@ -559,9 +557,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
@@ -594,8 +590,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -656,9 +651,7 @@ def main():
         if "LG" in params.decoding_method:
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -690,9 +683,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -719,9 +712,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -780,9 +773,7 @@ def main():
             decoding_graph.scores *= params.ngram_lm_scale
         else:
             word_table = None
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
         word_table = None
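Note: k2.trivial_graph builds an acceptor over token IDs 0..max_token that accepts any token sequence; it is the fallback decoding graph for fast_beam_search when no LG graph is supplied. A standalone sketch (the vocab size is invented):

    import k2
    import torch

    # accepts any sequence over token IDs 0..1999, mirroring vocab_size - 1 above
    decoding_graph = k2.trivial_graph(1999, device=torch.device("cpu"))
    print(decoding_graph.num_arcs)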
@@ -793,17 +784,25 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
+    multi_dataset = MultiDataset(args.manifest_dir)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    # test_clean_cuts = librispeech.test_clean_cuts()
+    # test_other_cuts = librispeech.test_other_cuts()
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets_cuts = multi_dataset.test_cuts()
+
+    test_sets = test_sets_cuts.keys()
+    test_dl = [
+        librispeech.test_dataloaders(test_sets_cuts[cuts_name])
+        for cuts_name in test_sets
+    ]
 
     for test_set, test_dl in zip(test_sets, test_dl):
+        logging.info(f"Start decoding test set: {test_set}")
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
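Note: the pairing relies on dict insertion order (guaranteed in Python 3.7+), so test_sets_cuts.keys() and the dataloader list stay aligned in zip. A stand-in sketch of the pattern with dummy values in place of CutSets and DataLoaders:

    test_sets_cuts = {"aidatatang": "cuts-a", "aishell": "cuts-b"}  # dummy values

    test_sets = test_sets_cuts.keys()
    test_dl = [f"dataloader({test_sets_cuts[name]})" for name in test_sets]

    for test_set, dl in zip(test_sets, test_dl):
        print(f"Start decoding test set: {test_set} using {dl}")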

View File

@@ -19,6 +19,7 @@ import glob
 import logging
 import re
 from pathlib import Path
+from typing import List, Dict
 
 import lhotse
 from lhotse import CutSet, load_manifest_lazy
@@ -215,3 +216,69 @@ class MultiDataset:
         # kespeech_dev_phase2_cuts,
         # wenetspeech_dev_cuts,
         # ]
+
+    def test_cuts(self) -> Dict[str, CutSet]:
+        logging.info("About to get multidataset test cuts")
+
+        # Aidatatang_200zh
+        logging.info("Loading Aidatatang_200zh TEST set in lazy mode")
+        aidatatang_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aidatatang_cuts_test.jsonl.gz"
+        )
+
+        # AISHELL
+        logging.info("Loading Aishell TEST set in lazy mode")
+        aishell_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
+        )
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 TEST set in lazy mode")
+        aishell2_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+        )
+
+        # AISHELL-4
+        logging.info("Loading Aishell-4 TEST set in lazy mode")
+        aishell4_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
+        )
+
+        # Ali-Meeting
+        logging.info("Loading Ali-Meeting TEST set in lazy mode")
+        alimeeting_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
+        )
+
+        # MagicData
+        logging.info("Loading MagicData TEST set in lazy mode")
+        magicdata_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
+        )
+
+        # KeSpeech
+        logging.info("Loading KeSpeech TEST set in lazy mode")
+        kespeech_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
+        )
+
+        # WeNetSpeech
+        logging.info("Loading WeNetSpeech TEST set in lazy mode")
+        wenetspeech_test_meeting_cuts = load_manifest_lazy(
+            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
+        )
+        wenetspeech_test_net_cuts = load_manifest_lazy(
+            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
+        )
+
+        return {
+            "aidatatang": aidatatang_test_cuts,
+            # "alimeeting": alimeeting_test_cuts,
+            "aishell": aishell_test_cuts,
+            "aishell-2": aishell2_test_cuts,
+            "aishell-4": aishell4_test_cuts,
+            "magicdata": magicdata_test_cuts,
+            "kespeech": kespeech_test_cuts,
+            "wenetspeech-meeting": wenetspeech_test_meeting_cuts,
+            "wenetspeech-net": wenetspeech_test_net_cuts,
+        }
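Note: alimeeting is loaded above but commented out of the returned dict, so it is silently skipped at decoding time. Hypothetical usage of the new method (the manifest directory is an assumption):

    from pathlib import Path

    from multi_dataset import MultiDataset  # the module changed in this commit

    multi_dataset = MultiDataset(Path("data/fbank"))  # assumed manifest directory
    for name, cuts in multi_dataset.test_cuts().items():
        print(name, cuts)  # each value is a lazily opened lhotse CutSet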

View File

@@ -605,11 +605,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -809,17 +809,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
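Note: the reflow above does not change behavior. For reference, a worked example of the warm-up schedule it expresses, assuming warm_step = 2000 and s = params.simple_loss_scale = 0.5 (both values invented here):

    warm_step, s = 2000, 0.5  # invented values

    for batch_idx_train in (0, 1000, 2000):
        simple_loss_scale = (
            s
            if batch_idx_train >= warm_step
            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        )
        pruned_loss_scale = (
            1.0
            if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
        )
        # 0 -> (1.0, 0.1); 1000 -> (0.75, 0.55); 2000 -> (0.5, 1.0)
        print(batch_idx_train, simple_loss_scale, pruned_loss_scale)

The simple loss thus dominates early training and is scaled down to s by warm_step, while the pruned loss ramps up from 0.1 to full weight over the same window.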
@@ -1192,7 +1191,7 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 600.0:
+        if c.duration < 1.0 or c.duration > 20.0:
             # logging.warning(
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
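Note: the cap drops from 600 s to 20 s, which suits the multi-dataset corpora where long meeting-style recordings would otherwise trigger OOM. This predicate is applied through lhotse's lazy CutSet.filter, roughly as sketched below (the wrapper function is an assumption; the predicate mirrors the diff):

    from lhotse import CutSet

    def remove_short_and_long_utt(c) -> bool:
        # keep cuts between 1 s and 20 s, mirroring the predicate above
        return not (c.duration < 1.0 or c.duration > 20.0)

    def filter_train_cuts(train_cuts: CutSet) -> CutSet:
        # CutSet.filter is lazy: the predicate runs as cuts are iterated
        return train_cuts.filter(remove_short_and_long_utt)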
@@ -1238,14 +1237,14 @@ def run(rank, world_size, args):
     valid_cuts = multi_dataset.dev_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints: