mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
updates for multi dataset decoding
This commit is contained in:
parent
49b0a6d952
commit
a91c90636b
@ -105,7 +105,7 @@ class LibriSpeechAsrDataModule:
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
default=300.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
|
@ -117,6 +117,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from multi_dataset import MultiDataset
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -191,7 +192,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
default="data/lang_bpe_2000/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
@ -273,8 +274,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -425,10 +425,7 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
@ -534,6 +531,7 @@ def decode_dataset(
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
@ -549,8 +547,8 @@ def decode_dataset(
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
hyp_text = "".join(hyp_words)
|
||||
this_batch.append((cut_id, ref_text, hyp_text))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
@ -559,9 +557,7 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
@ -594,8 +590,7 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
@ -656,9 +651,7 @@ def main():
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -690,9 +683,9 @@ def main():
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -719,9 +712,9 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
@ -780,9 +773,7 @@ def main():
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
@ -793,17 +784,25 @@ def main():
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
multi_dataset = MultiDataset(args.manifest_dir)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
# test_clean_cuts = librispeech.test_clean_cuts()
|
||||
# test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
# test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
test_sets_cuts = multi_dataset.test_cuts()
|
||||
|
||||
test_sets = test_sets_cuts.keys()
|
||||
test_dl = [
|
||||
librispeech.test_dataloaders(test_sets_cuts[cuts_name])
|
||||
for cuts_name in test_sets
|
||||
]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
logging.info(f"Start decoding test set: {test_set}")
|
||||
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
|
@ -19,6 +19,7 @@ import glob
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
|
||||
import lhotse
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
@ -215,3 +216,69 @@ class MultiDataset:
|
||||
# kespeech_dev_phase2_cuts,
|
||||
# wenetspeech_dev_cuts,
|
||||
# ]
|
||||
|
||||
def test_cuts(self) -> Dict[str, CutSet]:
|
||||
logging.info("About to get multidataset test cuts")
|
||||
|
||||
# Aidatatang_200zh
|
||||
logging.info("Loading Aidatatang_200zh TEST set in lazy mode")
|
||||
aidatatang_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aidatatang_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL
|
||||
logging.info("Loading Aishell TEST set in lazy mode")
|
||||
aishell_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 TEST set in lazy mode")
|
||||
aishell2_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL-4
|
||||
logging.info("Loading Aishell-4 TEST set in lazy mode")
|
||||
aishell4_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# Ali-Meeting
|
||||
logging.info("Loading Ali-Meeting TEST set in lazy mode")
|
||||
alimeeting_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# MagicData
|
||||
logging.info("Loading MagicData TEST set in lazy mode")
|
||||
magicdata_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# KeSpeech
|
||||
logging.info("Loading KeSpeech TEST set in lazy mode")
|
||||
kespeech_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
|
||||
)
|
||||
|
||||
# WeNetSpeech
|
||||
logging.info("Loading WeNetSpeech TEST set in lazy mode")
|
||||
wenetspeech_test_meeting_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
|
||||
)
|
||||
wenetspeech_test_net_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
|
||||
)
|
||||
|
||||
return {
|
||||
"aidatatang": aidatatang_test_cuts,
|
||||
# "alimeeting": alimeeting_test_cuts,
|
||||
"aishell": aishell_test_cuts,
|
||||
"aishell-2": aishell2_test_cuts,
|
||||
"aishell-4": aishell4_test_cuts,
|
||||
"magicdata": magicdata_test_cuts,
|
||||
"kespeech": kespeech_test_cuts,
|
||||
"wenetspeech-meeting": wenetspeech_test_meeting_cuts,
|
||||
"wenetspeech-net": wenetspeech_test_net_cuts,
|
||||
}
|
||||
|
@ -605,11 +605,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
|
||||
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
assert (
|
||||
params.use_transducer or params.use_ctc
|
||||
), (f"At least one of them should be True, "
|
||||
assert params.use_transducer or params.use_ctc, (
|
||||
f"At least one of them should be True, "
|
||||
f"but got params.use_transducer={params.use_transducer}, "
|
||||
f"params.use_ctc={params.use_ctc}")
|
||||
f"params.use_ctc={params.use_ctc}"
|
||||
)
|
||||
|
||||
encoder_embed = get_encoder_embed(params)
|
||||
encoder = get_encoder_model(params)
|
||||
@ -809,17 +809,16 @@ def compute_loss(
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s if batch_idx_train >= warm_step
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0 if batch_idx_train >= warm_step
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
loss += (
|
||||
simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
)
|
||||
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
if params.use_ctc:
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
@ -1192,7 +1191,7 @@ def run(rank, world_size, args):
|
||||
# You should use ../local/display_manifest_statistics.py to get
|
||||
# an utterance duration distribution for your dataset to select
|
||||
# the threshold
|
||||
if c.duration < 1.0 or c.duration > 600.0:
|
||||
if c.duration < 1.0 or c.duration > 20.0:
|
||||
# logging.warning(
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
@ -1238,14 +1237,14 @@ def run(rank, world_size, args):
|
||||
valid_cuts = multi_dataset.dev_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
# if not params.print_diagnostics:
|
||||
# scan_pessimistic_batches_for_oom(
|
||||
# model=model,
|
||||
# train_dl=train_dl,
|
||||
# optimizer=optimizer,
|
||||
# sp=sp,
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
|
Loading…
x
Reference in New Issue
Block a user