From 2cb0092b09a5851b6cbdf62dfea097c652af5f1f Mon Sep 17 00:00:00 2001
From: jinzr <60612200+JinZr@users.noreply.github.com>
Date: Sat, 2 Sep 2023 01:33:26 +0800
Subject: [PATCH] this commit finalizes the recipe (hopefully)

---
 egs/multi_zh-hans/ASR/README.md               | 23 +++++
 egs/multi_zh-hans/ASR/RESULTS.md              | 41 ++++++++
 egs/multi_zh-hans/ASR/prepare.sh              |  6 +-
 .../ASR/zipformer/asr_datamodule.py           | 89 +------------------
 egs/multi_zh-hans/ASR/zipformer/ctc_decode.py |  8 +-
 egs/multi_zh-hans/ASR/zipformer/decode.py     | 14 +--
 .../ASR/zipformer/onnx_decode.py              | 15 ++--
 .../ASR/zipformer/streaming_decode.py         | 59 +++++------
 egs/multi_zh-hans/ASR/zipformer/train.py      | 34 +++----
 9 files changed, 118 insertions(+), 171 deletions(-)
 create mode 100644 egs/multi_zh-hans/ASR/README.md
 create mode 100644 egs/multi_zh-hans/ASR/RESULTS.md

diff --git a/egs/multi_zh-hans/ASR/README.md b/egs/multi_zh-hans/ASR/README.md
new file mode 100644
index 000000000..f74c6d382
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/README.md
@@ -0,0 +1,23 @@
+
+# Introduction
+
+This recipe includes scripts for training a Zipformer model using multiple Chinese datasets.
+
+# Included Training Sets
+1. THCHS-30
+2. AiShell-{1,2,4}
+3. ST-CMDS
+4. Primewords
+5. MagicData
+6. Aidatatang_200zh
+7. AliMeeting
+8. WeNetSpeech
+9. KeSpeech-ASR
+
+# Included Test Sets
+1. AiShell-{1,2,4}
+2. Aidatatang_200zh
+3. AliMeeting
+4. MagicData
+5. KeSpeech-ASR
+6. WeNetSpeech
\ No newline at end of file
diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md
new file mode 100644
index 000000000..868b4da74
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/RESULTS.md
@@ -0,0 +1,41 @@
+## Results
+
+### Multi Chinese datasets char-based training results (non-streaming) on the Zipformer model
+
+This is the [pull request](https://github.com/k2-fsa/icefall/pull/1130) in icefall.
+
+#### Non-streaming
+
+Best results (number of params: ~68M):
+
+The training command:
+
+```
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 23 \
+  --use-fp16 1 \
+  --max-duration 500 \
+  --num-workers 8
+```
+
+The Character Error Rates (CERs) listed below are produced by the checkpoint from the 20th epoch, using greedy search and a BPE model (vocab size: 2000, with byte fallback enabled).
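+
+The decoding command should look roughly as follows (a sketch rather than the exact command, which is not recorded in this commit; --epoch/--avg here correspond to the 20th-epoch checkpoint mentioned above):
+
+```
+./zipformer/decode.py \
+  --epoch 20 \
+  --avg 1 \
+  --exp-dir zipformer/exp \
+  --max-duration 600 \
+  --decoding-method greedy_search
+```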
+
+| Datasets          | aidatatang_200zh | aidatatang_200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech  | WenetSpeech |
+|-------------------|------------------|------------------|------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|--------------|--------------|--------------|-------------|--------------|-------------|
+| Zipformer CER (%) | dev              | test             | eval       | test       | dev       | test      | dev       | test      | test      | dev       | test      | dev phase1   | dev phase2   | test         | dev         | test meeting | test net    |
+|                   | 3.2              | 3.67             | 23.15      | 24.78      | 2.91      | 3.04      | 3.59      | 4.03      | 15.68     | 3.68      | 3.12      | 6.69         | 3.19         | 8.01         | 9.32        | 7.05         | 8.78        |
+
+
+The pre-trained model is available here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh
index 767836422..ccc1e5ea4 100755
--- a/egs/multi_zh-hans/ASR/prepare.sh
+++ b/egs/multi_zh-hans/ASR/prepare.sh
@@ -229,10 +229,12 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
     cd data/fbank
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
-    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M.jsonl.gz) .
-    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S.jsonl.gz) .
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
+
+    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_1000) .
+    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) .
+    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech
     cd ../..
   else
     log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
index e7318b0dc..3518eee3f 100644
--- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
+++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
@@ -52,7 +52,7 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)
 
 
-class LibriSpeechAsrDataModule:
+class AsrDataModule:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -82,20 +82,6 @@ class LibriSpeechAsrDataModule:
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )
-        group.add_argument(
-            "--full-libri",
-            type=str2bool,
-            default=True,
-            help="""Used only when --mini-libri is False.When enabled,
-            use 960h LibriSpeech. 
Otherwise, use 100h subset.""", - ) - group.add_argument( - "--mini-libri", - type=str2bool, - default=False, - help="True for mini librispeech", - ) - group.add_argument( "--manifest-dir", type=Path, @@ -400,76 +386,3 @@ class LibriSpeechAsrDataModule: num_workers=self.args.num_workers, ) return test_dl - - @lru_cache() - def train_clean_5_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get train-clean-5 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" - ) - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_clean_2_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get dev-clean-2 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" - ) diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py index 4db50b981..2fa73475b 100755 --- a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py @@ -87,8 +87,8 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model +from asr_datamodule import AsrDataModule +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -598,7 +598,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -811,7 +811,7 @@ def main(): # we need cut ids to display recognition results. 
args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) + librispeech = AsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 5cd133df8..94a4c4a3c 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -105,7 +105,7 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -609,7 +609,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -784,15 +784,9 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) + data_module = AsrDataModule(args) multi_dataset = MultiDataset(args.manifest_dir) - # test_clean_cuts = librispeech.test_clean_cuts() - # test_other_cuts = librispeech.test_other_cuts() - - # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - # test_other_dl = librispeech.test_dataloaders(test_other_cuts) - def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 if T <= 0: @@ -805,7 +799,7 @@ def main(): test_sets = test_sets_cuts.keys() test_dl = [ - librispeech.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) for cuts_name in test_sets ] diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py index 2aca36ca9..bea6bc5c4 100755 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py @@ -76,12 +76,11 @@ from typing import List, Tuple import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel +from asr_datamodule import AsrDataModule +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable def get_parser(): @@ -263,7 +262,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() assert ( @@ -290,7 +289,7 @@ def main(): # we need cut ids to display recognition results. 
args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) + librispeech = AsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() @@ -303,7 +302,9 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() - results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) end_time = time.time() elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py index 44ff392a3..1dcd74cb2 100755 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py @@ -40,7 +40,7 @@ import k2 import numpy as np import sentencepiece as spm import torch -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states @@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -404,9 +398,7 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -494,9 +486,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = 
processed_lens + encoder_out_lens @@ -517,9 +507,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) @@ -577,9 +565,7 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = get_init_states( - model=model, batch_size=1, device=device - ) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -649,9 +635,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -684,8 +668,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -703,7 +686,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -718,9 +701,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
@@ -760,9 +741,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -789,9 +770,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -846,7 +827,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + librispeech = AsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 6332f7e37..4f2d728be 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ - --full-libri 1 \ --max-duration 1000 # For streaming model training: @@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --use-fp16 1 \ --exp-dir zipformer/exp \ --causal 1 \ - --full-libri 1 \ --max-duration 1000 It supports training with: @@ -65,7 +63,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -1173,14 +1171,10 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - librispeech = LibriSpeechAsrDataModule(args) + data_module = AsrDataModule(args) multi_dataset = MultiDataset(args.manifest_dir) train_cuts = multi_dataset.train_cuts() - # train_cuts = librispeech.train_clean_100_cuts() - # if params.full_libri: - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1228,23 +1222,21 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = librispeech.train_dataloaders( + train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - # valid_cuts = librispeech.dev_clean_cuts() - # valid_cuts += librispeech.dev_other_cuts() valid_cuts = multi_dataset.dev_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_dl = data_module.valid_dataloaders(valid_cuts) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # sp=sp, - # params=params, - # ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -1374,7 +1366,7 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - 
LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir)
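
Note on the datamodule rename: with `LibriSpeechAsrDataModule` now `AsrDataModule` and the LibriSpeech-specific cut helpers removed, test sets come from `MultiDataset`, as the decode.py and train.py hunks above show. Below is a minimal sketch of the resulting wiring; the `multi_dataset` module name and a `test_cuts()` method returning a dict of name -> CutSet are assumptions inferred from their usage in decode.py, since neither appears in full in this diff:

```python
from asr_datamodule import AsrDataModule
from multi_dataset import MultiDataset  # assumed module name


def build_test_dataloaders(args):
    # AsrDataModule builds the dataloaders; MultiDataset selects the cuts.
    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir)
    # test_cuts() is assumed to return {test_set_name: CutSet}.
    test_sets_cuts = multi_dataset.test_cuts()
    return {
        name: data_module.test_dataloaders(cuts)
        for name, cuts in test_sets_cuts.items()
    }
```

ctc_decode.py, onnx_decode.py, and streaming_decode.py still call `test_clean_cuts()`/`test_other_cuts()`, which this patch deletes from `AsrDataModule`; those call sites presumably need the same `MultiDataset` treatment as decode.py.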