updates for multi dataset decoding

JinZr 2023-07-24 15:13:06 +08:00
parent 49b0a6d952
commit a91c90636b
4 changed files with 117 additions and 52 deletions

View File

@@ -105,7 +105,7 @@ class LibriSpeechAsrDataModule:
         group.add_argument(
             "--max-duration",
             type=int,
-            default=200.0,
+            default=300.0,
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )

View File

@@ -117,6 +117,7 @@ from beam_search import (
     modified_beam_search,
 )
 from train import add_model_arguments, get_params, get_model
+from multi_dataset import MultiDataset

 from icefall.checkpoint import (
     average_checkpoints,
@@ -191,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
+        default="data/lang_bpe_2000/bpe.model",
         help="Path to the BPE model",
     )
@@ -273,8 +274,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -425,10 +425,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -534,6 +531,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
@@ -549,8 +547,8 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                hyp_text = "".join(hyp_words)
+                this_batch.append((cut_id, ref_text, hyp_text))

             results[name].extend(this_batch)
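
The two hunks above switch decode.py from word-level to character-level scoring: reference texts are stripped of spaces and split into characters, while the decoded word pieces are joined back into a single string. A minimal sketch of the transformation (the sample strings are illustrative, not from the commit):

    ref_text = "甚至 出现 交易 停滞"             # a supervision text with spaces
    hyp_words = ["甚至", "出现", "交易", "停滞"]  # e.g. output of sp.decode(...).split()

    ref_chars = list(ref_text.replace(" ", ""))  # ['甚', '至', '出', '现', ...]
    hyp_text = "".join(hyp_words)                # "甚至出现交易停滞"
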
@@ -559,9 +557,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
@@ -594,8 +590,7 @@ def save_results(
         test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
         errs_info = (
-            params.res_dir
-            / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_info, "w") as f:
             print("settings\tWER", file=f)
@@ -656,9 +651,7 @@ def main():
     if "LG" in params.decoding_method:
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -690,9 +683,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -719,9 +712,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -780,9 +773,7 @@ def main():
             decoding_graph.scores *= params.ngram_lm_scale
         else:
             word_table = None
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
         word_table = None
@@ -793,17 +784,25 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
+    multi_dataset = MultiDataset(args.manifest_dir)

-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    # test_clean_cuts = librispeech.test_clean_cuts()
+    # test_other_cuts = librispeech.test_other_cuts()

-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)

-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets_cuts = multi_dataset.test_cuts()
+
+    test_sets = test_sets_cuts.keys()
+    test_dl = [
+        librispeech.test_dataloaders(test_sets_cuts[cuts_name])
+        for cuts_name in test_sets
+    ]

     for test_set, test_dl in zip(test_sets, test_dl):
+        logging.info(f"Start decoding test set: {test_set}")
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
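
One readability note on the new loop above: test_dl names both the list of dataloaders and the loop variable, and test_sets is a dict keys view rather than a list. A functionally equivalent sketch without the shadowing (the renaming is illustrative only):

    test_sets_cuts = multi_dataset.test_cuts()
    test_set_names = list(test_sets_cuts.keys())
    test_dls = [
        librispeech.test_dataloaders(test_sets_cuts[name]) for name in test_set_names
    ]
    for test_set_name, dl in zip(test_set_names, test_dls):
        logging.info(f"Start decoding test set: {test_set_name}")
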

View File

@@ -19,6 +19,7 @@ import glob
 import logging
 import re
 from pathlib import Path
+from typing import List, Dict

 import lhotse
 from lhotse import CutSet, load_manifest_lazy
@@ -215,3 +216,69 @@ class MultiDataset:
             #     kespeech_dev_phase2_cuts,
             #     wenetspeech_dev_cuts,
             # ]
+
+    def test_cuts(self) -> Dict[str, CutSet]:
+        logging.info("About to get multidataset test cuts")
+
+        # Aidatatang_200zh
+        logging.info("Loading Aidatatang_200zh TEST set in lazy mode")
+        aidatatang_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aidatatang_cuts_test.jsonl.gz"
+        )
+
+        # AISHELL
+        logging.info("Loading Aishell TEST set in lazy mode")
+        aishell_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
+        )
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 TEST set in lazy mode")
+        aishell2_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+        )
+
+        # AISHELL-4
+        logging.info("Loading Aishell-4 TEST set in lazy mode")
+        aishell4_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
+        )
+
+        # Ali-Meeting
+        logging.info("Loading Ali-Meeting TEST set in lazy mode")
+        alimeeting_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
+        )
+
+        # MagicData
+        logging.info("Loading MagicData TEST set in lazy mode")
+        magicdata_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
+        )
+
+        # KeSpeech
+        logging.info("Loading KeSpeech TEST set in lazy mode")
+        kespeech_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
+        )
+
+        # WeNetSpeech
+        logging.info("Loading WeNetSpeech TEST set in lazy mode")
+        wenetspeech_test_meeting_cuts = load_manifest_lazy(
+            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
+        )
+        wenetspeech_test_net_cuts = load_manifest_lazy(
+            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
+        )
+
+        return {
+            "aidatatang": aidatatang_test_cuts,
+            # "alimeeting": alimeeting_test_cuts,
+            "aishell": aishell_test_cuts,
+            "aishell-2": aishell2_test_cuts,
+            "aishell-4": aishell4_test_cuts,
+            "magicdata": magicdata_test_cuts,
+            "kespeech": kespeech_test_cuts,
+            "wenetspeech-meeting": wenetspeech_test_meeting_cuts,
+            "wenetspeech-net": wenetspeech_test_net_cuts,
+        }
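
A hypothetical usage sketch for the new test_cuts() API, mirroring how decode.py consumes it above; the manifest directory path and layout are assumptions:

    from multi_dataset import MultiDataset

    multi_dataset = MultiDataset("data/fbank")  # assumed manifest dir
    for name, cuts in multi_dataset.test_cuts().items():
        print(name)  # e.g. "aishell", "wenetspeech-net"; cuts is a lazy CutSet
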

View File

@@ -605,11 +605,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:

 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )

     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
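
The assert reformat above keeps the condition outside the parentheses, which matters in Python: wrapping the whole statement in one pair of parens would assert a non-empty tuple, which is always truthy. A minimal illustration:

    x = False
    assert not x, (
        "multi-line messages are fine "
        "when only the message is parenthesized"
    )
    # assert (x, "oops")  # would never fail: a non-empty tuple is truthy
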
@@ -809,17 +809,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
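
The reformatted expressions implement a linear warm-up schedule: simple_loss_scale decays from 1.0 to s over warm_step batches while pruned_loss_scale grows from 0.1 to 1.0. A standalone sketch (s and warm_step stand in for the corresponding params values):

    def loss_scales(batch_idx_train: int, s: float, warm_step: int):
        if batch_idx_train >= warm_step:
            return s, 1.0
        frac = batch_idx_train / warm_step
        return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

    print(loss_scales(0, 0.5, 2000))     # (1.0, 0.1) at the start
    print(loss_scales(2000, 0.5, 2000))  # (0.5, 1.0) after warm-up
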
@@ -1192,7 +1191,7 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 600.0:
+        if c.duration < 1.0 or c.duration > 20.0:
             # logging.warning(
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
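
The tightened bound drops training cuts longer than 20 s (previously 600 s). A sketch of how such a predicate is typically applied to a lhotse CutSet (the surrounding variable names are assumptions):

    def remove_short_and_long_utt(c) -> bool:
        # Keep cuts between 1 and 20 seconds, matching the new threshold.
        return 1.0 <= c.duration <= 20.0

    # train_cuts = train_cuts.filter(remove_short_and_long_utt)  # lazy filter
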
@@ -1238,14 +1237,14 @@ def run(rank, world_size, args):
     valid_cuts = multi_dataset.dev_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints: