diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index c47964b07..e7318b0dc 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -105,7 +105,7 @@ class LibriSpeechAsrDataModule: group.add_argument( "--max-duration", type=int, - default=200.0, + default=300.0, help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 93680602e..df7124c0b 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -117,6 +117,7 @@ from beam_search import ( modified_beam_search, ) from train import add_model_arguments, get_params, get_model +from multi_dataset import MultiDataset from icefall.checkpoint import ( average_checkpoints, @@ -191,7 +192,7 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - default="data/lang_bpe_500/bpe.model", + default="data/lang_bpe_2000/bpe.model", help="Path to the BPE model", ) @@ -273,8 +274,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -425,10 +425,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -534,6 +531,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + texts = [list(str(text).replace(" ", "")) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -549,8 +547,8 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + hyp_text = "".join(hyp_words) + this_batch.append((cut_id, ref_text, hyp_text)) results[name].extend(this_batch) @@ -559,9 +557,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -594,8 +590,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -656,9 +651,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -690,9 +683,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -719,9 +712,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -780,9 +773,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -793,17 +784,25 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + # test_clean_cuts = librispeech.test_clean_cuts() + # test_other_cuts = librispeech.test_other_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + # test_other_dl = librispeech.test_dataloaders(test_other_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + librispeech.test_dataloaders(test_sets_cuts[cuts_name]) + for cuts_name in test_sets + ] for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py index f2cd80393..09e98ba0c 100644 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -19,6 +19,7 @@ import glob import logging import re from pathlib import Path +from typing import List, Dict import lhotse from lhotse import CutSet, load_manifest_lazy @@ -215,3 +216,69 @@ class MultiDataset: # kespeech_dev_phase2_cuts, # wenetspeech_dev_cuts, # ] + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # Aidatatang_200zh + logging.info("Loading Aidatatang_200zh TEST set in lazy mode") + aidatatang_test_cuts = load_manifest_lazy( + self.fbank_dir / "aidatatang_cuts_test.jsonl.gz" + ) + + # AISHELL + logging.info("Loading Aishell TEST set in lazy mode") + aishell_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_test.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 TEST set in lazy mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 TEST set in lazy mode") + aishell4_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_test.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting TEST set in lazy mode") + alimeeting_test_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData TEST set in lazy mode") + magicdata_test_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_test.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech TEST set in lazy mode") + kespeech_test_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech TEST set in lazy mode") + wenetspeech_test_meeting_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" + ) + wenetspeech_test_net_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" + ) + + return { + "aidatatang": aidatatang_test_cuts, + # "alimeeting": alimeeting_test_cuts, + "aishell": aishell_test_cuts, + "aishell-2": aishell2_test_cuts, + "aishell-4": aishell4_test_cuts, + "magicdata": magicdata_test_cuts, + "kespeech": kespeech_test_cuts, + "wenetspeech-meeting": wenetspeech_test_meeting_cuts, + "wenetspeech-net": wenetspeech_test_net_cuts, + } diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 1b7b4cc83..6332f7e37 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -605,11 +605,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: - assert ( - params.use_transducer or params.use_ctc - ), (f"At least one of them should be True, " + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}") + f"params.use_ctc={params.use_ctc}" + ) encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) @@ -809,17 +809,16 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss += ( - simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1192,7 +1191,7 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 600.0: + if c.duration < 1.0 or c.duration > 20.0: # logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) @@ -1238,14 +1237,14 @@ def run(rank, world_size, args): valid_cuts = multi_dataset.dev_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: