From a00c0c5279df82eb1530647f22edc1294244c4e8 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 7 Mar 2024 14:44:38 +0800 Subject: [PATCH] add speechio results --- egs/aishell/ASR/RESULTS.md | 2 +- egs/aishell2/ASR/RESULTS.md | 2 +- .../ASR/local/compute_fbank_aishell2.py | 23 +- egs/aishell2/ASR/prepare.sh | 4 +- egs/aishell4/ASR/README.md | 2 +- .../ASR/local/compute_fbank_aishell4.py | 25 +- egs/aishell4/ASR/prepare.sh | 4 +- .../ASR/local/compute_fbank_alimeeting.py | 23 +- egs/alimeeting/ASR/prepare.sh | 4 +- egs/multi_zh-hans/ASR/README.md | 2 +- .../local/compute_fbank_kespeech_dev_test.py | 18 +- .../local/compute_fbank_kespeech_splits.py | 9 +- .../ASR/local/compute_fbank_magicdata.py | 25 +- .../ASR/local/compute_fbank_primewords.py | 23 +- .../ASR/local/compute_fbank_stcmds.py | 21 +- .../ASR/local/compute_fbank_thchs30.py | 21 +- egs/multi_zh-hans/ASR/prepare.sh | 42 +- egs/multi_zh-hans/ASR/whisper/decode.py | 6 +- egs/multi_zh-hans/ASR/whisper/train.py | 14 +- egs/speechio/ASR/README.md | 15 + egs/speechio/ASR/RESULTS.md | 92 ++ .../ASR/local/compute_fbank_speechio.py | 25 +- .../ASR/local/display_manifest_statistics.py | 9 +- .../ASR/local/whisper_zipformer_fusion.py | 217 +++ egs/speechio/ASR/prepare.sh | 2 +- egs/speechio/ASR/whisper/asr_datamodule.py | 4 +- egs/speechio/ASR/whisper/decode.py | 6 +- egs/speechio/ASR/whisper/multi_dataset.py | 10 +- egs/speechio/ASR/zipformer/decode.py | 17 +- egs/speechio/ASR/zipformer/train.py | 1386 +---------------- .../compute_fbank_wenetspeech_dev_test.py | 18 +- .../local/compute_fbank_wenetspeech_splits.py | 22 +- egs/wenetspeech/ASR/prepare.sh | 16 +- .../onnx_check.py | 4 +- egs/wenetspeech/ASR/whisper/decode.py | 4 +- egs/wenetspeech/ASR/whisper/train.py | 1 + 36 files changed, 600 insertions(+), 1518 deletions(-) create mode 100644 egs/speechio/ASR/README.md create mode 100644 egs/speechio/ASR/RESULTS.md create mode 100644 egs/speechio/ASR/local/whisper_zipformer_fusion.py mode change 100644 => 120000 egs/speechio/ASR/zipformer/train.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 46d712fb2..355d1516d 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc | fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | ```bash -./prepare.sh +./prepare.sh export CUDA_VISIBLE_DEVICES="0,1" diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md index 32ad74b50..0b7ae9299 100644 --- a/egs/aishell2/ASR/RESULTS.md +++ b/egs/aishell2/ASR/RESULTS.md @@ -1,6 +1,6 @@ ## Results -### Aishell2 char-based training results +### Aishell2 char-based training results #### Pruned transducer stateless 5 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index dc2c4d153..557f22b0c 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,7 +49,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = 
False, whisper_fbank: bool = False): +def compute_fbank_aishell2( + num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(8, os.cpu_count()) @@ -69,7 +78,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, dataset_parts, ) if whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -84,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -129,5 +140,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_aishell2( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 418762be1..c959bd4d1 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail nj=30 -stage=1 -stop_stage=1 +stage=0 +stop_stage=7 perturb_speed=true diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md index 67fa17790..b96161762 100644 --- a/egs/aishell4/ASR/README.md +++ b/egs/aishell4/ASR/README.md @@ -3,7 +3,7 @@ This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets). -The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. +The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. 
Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. (From [Open Speech and Language Resources](https://www.openslr.org/111/)) diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 2ecae7b3d..b5f8468ac 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import LilcomChunkyWriter, CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,7 +49,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False): +def compute_fbank_aishell4( + num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/aishell4") output_dir = Path("data/fbank") num_jobs = min(8, os.cpu_count()) @@ -71,7 +80,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, ) if whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -87,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -100,7 +111,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, executor=ex, storage_type=LilcomChunkyWriter, ) - + logging.info("About splitting cuts into smaller chunks") cut_set = cut_set.trim_to_supervisions( keep_overlapping=False, @@ -140,5 +151,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_aishell4( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index 254ef08a4..38a36d97a 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=20 -stop_stage=20 +stage=-1 +stop_stage=7 perturb_speed=true diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index b5cbadc1e..09c873a34 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse 
import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,7 +49,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False): +def compute_fbank_alimeeting( + num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/alimeeting") output_dir = Path("data/fbank") num_jobs = min(8, os.cpu_count()) @@ -71,7 +80,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False ) if whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -86,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -140,5 +151,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_alimeeting( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 8580d9e2a..301ab0111 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=20 -stop_stage=20 +stage=-1 +stop_stage=7 perturb_speed=true # We assume dl_dir (download dir) contains the following diff --git a/egs/multi_zh-hans/ASR/README.md b/egs/multi_zh-hans/ASR/README.md index 537816a5d..1e60c733c 100644 --- a/egs/multi_zh-hans/ASR/README.md +++ b/egs/multi_zh-hans/ASR/README.md @@ -36,4 +36,4 @@ This recipe includes scripts for training Zipformer model using multiple Chinese 3. AliMeeting 4. MagicData 5. KeSpeech-ASR -6. WeNetSpeech \ No newline at end of file +6. WeNetSpeech diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py index 252a3c8d4..6f75dbfa4 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -17,14 +17,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging from pathlib import Path -import argparse import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
# Do this outside of main() in case it needs to take effect @@ -32,6 +40,7 @@ from icefall.utils import str2bool torch.set_num_threads(1) torch.set_num_interop_threads(1) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -52,6 +61,7 @@ def get_parser(): ) return parser + def compute_fbank_kespeech_dev_test(args): in_out_dir = Path("data/fbank/kespeech") # number of workers in dataloader @@ -70,7 +80,9 @@ def compute_fbank_kespeech_dev_test(args): if torch.cuda.is_available(): device = torch.device("cuda", 0) if args.whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device=device) + ) else: extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py index ec02b45af..c398411f6 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -25,16 +25,17 @@ from pathlib import Path import torch from lhotse import ( CutSet, - WhisperFbank, - WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, set_audio_duration_mismatch_tolerance, set_caching_enabled, ) from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -129,7 +130,9 @@ def compute_fbank_kespeech_splits(args): if torch.cuda.is_available(): device = torch.device("cuda", 0) if args.whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) else: extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py index 4ec009d26..192bffa9f 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py @@ -30,7 +30,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,6 +49,7 @@ from icefall.utils import get_executor, str2bool torch.set_num_threads(1) torch.set_num_interop_threads(1) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -62,7 +70,10 @@ def get_parser(): ) return parser -def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): + +def compute_fbank_magicdata( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/magicdata") output_dir = Path("data/fbank") num_jobs = min(8, os.cpu_count()) @@ -84,9 +95,11 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, list(manifests.keys()), dataset_parts, ) - + if 
args.whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -145,5 +158,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_magicdata( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py index e47cd430e..019b10d24 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py @@ -30,7 +30,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -43,7 +50,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): +def compute_fbank_primewords( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/primewords") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -65,9 +74,11 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False list(manifests.keys()), dataset_parts, ) - + if whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -128,5 +139,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_primewords( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py index 1ff4e1c28..f29ae5a46 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py @@ -30,7 +30,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -43,7 +50,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): +def compute_fbank_stcmds( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/stcmds") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -67,7 +76,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = 
False, wh ) if whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -126,5 +137,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_stcmds( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py index 1362180bb..4ad78e0ba 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py @@ -30,7 +30,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -43,7 +50,9 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): +def compute_fbank_thchs30( + num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False +): src_dir = Path("data/manifests/thchs30") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -71,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, w ) if whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) else: extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -132,5 +143,7 @@ if __name__ == "__main__": args = get_args() compute_fbank_thchs30( - num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + speed_perturb=args.speed_perturb, + whisper_fbank=args.whisper_fbank, ) diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index 96ae1cf60..fa515ed50 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=121 -stop_stage=121 +stage=-1 +stop_stage=100 num_splits=100 dl_dir=$PWD/download @@ -95,10 +95,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) . ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) . cd ../.. - else + else log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" exit 1 - fi + fi fi log "Dataset: AISHELL-4" @@ -115,10 +115,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) . cd ../.. - else + else log "Abort! 
Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3" exit 1 - fi + fi fi log "Dataset: ST-CMDS" @@ -261,7 +261,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then if [ ! -f data/manifests/.kespeech.done ]; then mkdir -p data/manifests - lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech + lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech touch data/manifests/.kespeech.done fi @@ -272,8 +272,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then python3 ./local/preprocess_kespeech.py touch data/fbank/.kespeech_preprocess_complete - fi - + fi + if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then log "Spliting KeSpeech train_phase1" lhotse split ${num_splits} \ @@ -281,7 +281,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then data/fbank/kespeech/train_phase1_split_${num_splits} touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done fi - + if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then log "Spliting KeSpeech train_phase2" lhotse split ${num_splits} \ @@ -289,7 +289,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then data/fbank/kespeech/train_phase2_split_${num_splits} touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done fi - + log "Compute KeSpeech fbank for train_phase1" ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1 @@ -314,7 +314,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then if [ ! -f data/manifests/.kespeech.done ]; then mkdir -p data/manifests - lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech + lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech touch data/manifests/.kespeech.done fi @@ -325,8 +325,8 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then python3 ./local/preprocess_kespeech.py --speed-perturb true touch data/fbank/.kespeech_preprocess_complete - fi - + fi + if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then log "Spliting KeSpeech train_phase1" lhotse split ${num_splits} \ @@ -334,7 +334,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then data/fbank/kespeech/train_phase1_split_${num_splits} touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done fi - + if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then log "Spliting KeSpeech train_phase2" lhotse split ${num_splits} \ @@ -342,7 +342,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then data/fbank/kespeech/train_phase2_split_${num_splits} touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done fi - + log "Compute KeSpeech fbank for train_phase1" ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true @@ -351,7 +351,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then log "Compute KeSpeech fbank for test/dev" # ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - + if [ ! 
-f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz") lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz @@ -422,7 +422,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} - + mkdir -p $lang_dir if [ ! -f $lang_dir/bpe.model ]; then ./local/train_bpe_model.py \ @@ -442,7 +442,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then --lexicon $lang_dir/lexicon.txt \ --bpe-model $lang_dir/bpe.model fi - + if [ ! -f $lang_dir/L.fst ]; then log "Converting L.pt to L.fst" ./shared/convert-k2-to-openfst.py \ @@ -463,7 +463,7 @@ fi if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)" - + if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then cd data ln -s ../../../../wenetspeech/ASR/data/lm . @@ -482,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then python ./local/compile_lg.py --lang-dir $lang_dir done fi - - diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py index 1452c86a3..aabb80eaf 100644 --- a/egs/multi_zh-hans/ASR/whisper/decode.py +++ b/egs/multi_zh-hans/ASR/whisper/decode.py @@ -52,14 +52,14 @@ import k2 import torch import torch.nn as nn import whisper - from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset from tn.chinese.normalizer import Normalizer from whisper.normalizers import BasicTextNormalizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from zhconv import convert -from lhotse.cut import Cut -from multi_dataset import MultiDataset + from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.env import get_env_info from icefall.utils import ( diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index 40d0dc893..c6d81b319 100644 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -34,10 +34,10 @@ torchrun --nproc-per-node 8 ./whisper/train.py \ --model-name medium """ -import os import argparse import copy import logging +import os import random import warnings from pathlib import Path @@ -52,13 +52,13 @@ import torch.multiprocessing as mp import torch.nn as nn import whisper from asr_datamodule import AsrDataModule -from multi_dataset import MultiDataset from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from label_smoothing import LabelSmoothingLoss from lhotse import CutSet, load_manifest from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed +from multi_dataset import MultiDataset from optim import Eden, ScaledAdam from torch import Tensor from torch.cuda.amp import GradScaler @@ -626,7 +626,9 @@ def train_one_epoch( f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}") + os.system( + f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" + ) try: with torch.cuda.amp.autocast(enabled=params.use_fp16): @@ -761,9 +763,7 @@ def run(rank, world_size, args): del model.alignment_heads if 
params.pretrained_model_path:
-        checkpoint = torch.load(
-            params.pretrained_model_path, map_location="cpu"
-        )
+        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
         if "model" not in checkpoint:
             model.load_state_dict(checkpoint, strict=True)
         else:
@@ -866,7 +866,7 @@ def run(rank, world_size, args):
 
     valid_cuts = multi_dataset.dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)
-    
+
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
diff --git a/egs/speechio/ASR/README.md b/egs/speechio/ASR/README.md
new file mode 100644
index 000000000..2675efd9b
--- /dev/null
+++ b/egs/speechio/ASR/README.md
@@ -0,0 +1,15 @@
+
+# Introduction
+
+This recipe includes decoding results from several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Pretrained Models
+
+The following table lists the pretrained models.
+
+|             | Huggingface                                              | Comment                                                                      |
+|-------------|----------------------------------------------------------|------------------------------------------------------------------------------|
+| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Trained using the [multi-hans-zh recipe](../../multi_zh-hans/ASR/zipformer/) |
+| `whisper`   | yuekai/icefall_asr_wenetspeech_whisper                   | Trained using the [wenetspeech recipe](../../wenetspeech/ASR/whisper/)       |
diff --git a/egs/speechio/ASR/RESULTS.md b/egs/speechio/ASR/RESULTS.md
new file mode 100644
index 000000000..07649e383
--- /dev/null
+++ b/egs/speechio/ASR/RESULTS.md
@@ -0,0 +1,111 @@
+## Results
+
+### SpeechIO Test Set Decoding Results
+
+##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
+
+| Test Set | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
+|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
+| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
+| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
+| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
+| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
+| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
+| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
+| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
+| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
+| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
+| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
+| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
+| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
+| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
+| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
+| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
+| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
+| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
+| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
+| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
+| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
+| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
+| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
+| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
+| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
+| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
+| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
+| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
+| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
+
+
+
+
+Command for decoding using fine-tuned whisper:
+```bash
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
+ln -s icefall_asr_wenetspeech_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt
+
+python3 ./whisper/decode.py \
+  --exp-dir whisper/exp_large_v2_wenetspeech \
+  --model-name large-v2 \
+  --epoch 999 --avg 1 \
+  --start-index 0 --end-index 26 \
+  --remove-whisper-encoder-input-length-restriction True \
+  --manifest-dir data/fbank \
+  --beam-size 1 --max-duration 50
+```
+Command for decoding using pretrained zipformer:
+```bash
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
+cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
+git lfs pull --include "exp/pretrained.pt"
+git lfs pull --include "data/lang_bpe_2000/*"
+ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
+ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
+wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
+mv words.txt ./data/lang_bpe_2000/
+
+./zipformer/decode.py \
+  --epoch 999 \
+  --avg 1 \
+  --blank-penalty 2.0 \
+  --use-averaged-model false \
+  --exp-dir ./zipformer/exp_pretrain \
+  --max-duration 600 \
+  --start-index 0 --end-index 26 \
+  --manifest-dir data/fbank_kaldi \
+  --decoding-method greedy_search
+```
+Command for fusing the above decoding results from whisper and zipformer:
+```bash
+python local/whisper_zipformer_fusion.py \
+  --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
+  --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
+  --output-log-dir ./results_fusion
+
+```
+
+See [./local/whisper_zipformer_fusion.py](./local/whisper_zipformer_fusion.py) for why the fusion helps; a short sketch of the idea follows.
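+
+The sketch below is not part of the recipe scripts; it assumes `kaldialign` is
+installed and reuses the alignment example from the script's docstring:
+```python
+import kaldialign
+
+ERR = "*"
+hyp_whisper = "我在任的时候"  # whisper is more prone to deletions
+hyp_zipformer = "你任的时候呢"  # zipformer is more prone to substitutions/insertions
+
+# Character-level alignment of the two hypotheses:
+# [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
+ali = kaldialign.align(hyp_whisper, hyp_zipformer, ERR)
+
+# Keep whisper's symbol except where whisper deleted one (ERR on the left);
+# there, fall back to zipformer's symbol.
+fused = "".join(z if w == ERR else w for w, z in ali)
+print(fused)  # -> 我在任的时候呢
+```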
+
+SpeechIO fbank features, decoding scripts, logs, and decoding results
+are available at
+
diff --git a/egs/speechio/ASR/local/compute_fbank_speechio.py b/egs/speechio/ASR/local/compute_fbank_speechio.py
index d6956781b..5b3489a9f 100644
--- a/egs/speechio/ASR/local/compute_fbank_speechio.py
+++ b/egs/speechio/ASR/local/compute_fbank_speechio.py
@@ -30,7 +30,14 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor, str2bool
@@ -44,7 +51,13 @@ torch.set_num_interop_threads(1)
 
 SPEECHIO_TESTSET_INDEX = 26  # Currently, from 0 - 26 test sets are open source.
 
-def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False, fbank_dir: str = "data/fbank", whisper_fbank: bool = False):
+
+def compute_fbank_speechio(
+    num_mel_bins: int = 80,
+    speed_perturb: bool = False,
+    fbank_dir: str = "data/fbank",
+    whisper_fbank: bool = False,
+):
     src_dir = Path("data/manifests")
     output_dir = Path(fbank_dir)
     num_jobs = min(8, os.cpu_count())
@@ -72,7 +85,9 @@ def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False,
     )
 
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
@@ -127,5 +142,7 @@ if __name__ == "__main__":
 
     args = get_args()
     compute_fbank_speechio(
-        num_mel_bins=args.num_mel_bins, fbank_dir=args.fbank_dir, whisper_fbank=args.whisper_fbank
+        num_mel_bins=args.num_mel_bins,
+        fbank_dir=args.fbank_dir,
+        whisper_fbank=args.whisper_fbank,
     )
diff --git a/egs/speechio/ASR/local/display_manifest_statistics.py b/egs/speechio/ASR/local/display_manifest_statistics.py
index b2f52d137..0c803bfcd 100644
--- a/egs/speechio/ASR/local/display_manifest_statistics.py
+++ b/egs/speechio/ASR/local/display_manifest_statistics.py
@@ -35,15 +35,18 @@ def main():
         idx = f"{i}".zfill(2)
         dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
 
-    prefix="speechio"
-    suffix="jsonl.gz"
+    prefix = "speechio"
+    suffix = "jsonl.gz"
 
     for partition in dataset_parts:
         path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
 
         cuts = load_manifest_lazy(path)
-        print(f"===================Duration statistics of {partition}===================")
+        print(
+            f"===================Duration statistics of {partition}==================="
+        )
         cuts.describe()
 
+
 if __name__ == "__main__":
     main()
diff --git a/egs/speechio/ASR/local/whisper_zipformer_fusion.py b/egs/speechio/ASR/local/whisper_zipformer_fusion.py
new file mode 100644
index 000000000..04c5e75f0
--- /dev/null
+++ b/egs/speechio/ASR/local/whisper_zipformer_fusion.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+#
+# Copyright 2024 Author: Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding 
multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file uses whisper and zipformer decoding results to generate fused decoding results.
+Since the whisper model is more likely to make deletion errors and the zipformer model is more likely to make substitution and insertion errors,
+we trust the whisper model when it makes substitution and insertion errors, and trust the zipformer model where whisper makes deletion errors.
+
+Usage:
+    python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import kaldialign
+
+from icefall.utils import store_transcripts, write_error_stats
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--whisper-log-dir",
+        type=str,
+        default="./recogs_whisper",
+        help="The directory containing the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
+    )
+    parser.add_argument(
+        "--zipformer-log-dir",
+        type=str,
+        default="./recogs_zipformer",
+        help="The directory containing the zipformer logs",
+    )
+    parser.add_argument(
+        "--output-log-dir",
+        type=str,
+        default="./results_fusion",
+        help="The directory to store the fusion logs",
+    )
+    return parser
+
+
+def save_results(
+    res_dir: Path,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+
+    suffix = "epoch-999-avg-1"
+
+    for key, results in results_dict.items():
+        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        print(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
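+        # write_error_stats writes per-word error statistics and alignments
+        # into errs_filename and returns the WER for this decoding setting.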
+ errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + print("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + print(s) + + +def extract_hyp_ref_wavname(filename): + """ + 0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升'] + 0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升 + """ + hyps, refs, wav_name = [], [], [] + with open(filename, "r") as f: + for line in f: + if "ref" in line: + ref = line.split("ref=")[1].strip() + ref = ref[2:-2] + list_elements = ref.split("', '") + ref = "".join(list_elements) + refs.append(ref) + elif "hyp" in line: + hyp = line.split("hyp=")[1].strip() + hyps.append(hyp) + wav_name.append(line.split(":")[0]) + return hyps, refs, wav_name + + +def get_pair_filenames( + whisper_log_dir, + zipformer_log_dir, + whisper_suffix="beam-search-epoch-999-avg-1", + zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0", +): + results = [] + start_index, end_index = 0, 26 + dataset_parts = [] + for i in range(start_index, end_index + 1): + idx = f"{i}".zfill(2) + dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") + for partition in dataset_parts: + whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt" + zipformer_filename = ( + f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt" + ) + results.append((whisper_filename, zipformer_filename)) + return results + + +def fusion_hyps_trust_substituion_insertion( + hyps_whisper, hyps_zipformer, refs, ERR="*" +): + """ + alignment example: + [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')] + left is whisper, right is zipformer + for whisper substitution, use left + for whisper insertion, use left + for whisper deletion, use right + """ + hyps_fusion = [] + for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs): + ali = kaldialign.align(hyp_w, hyp_z, ERR) + hyp_f = "" + for a in ali: + if a[0] == ERR: + hyp_f += a[1] + else: + hyp_f += a[0] + hyps_fusion.append(hyp_f) + return hyps_fusion + + +def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"): + """ + alignment example: + [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')] + left is whisper, right is zipformer + for whisper substitution, use left + for whisper insertion, use right + for whisper deletion, use right + """ + hyps_fusion = [] + for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs): + ali = kaldialign.align(hyp_w, hyp_z, ERR) + hyp_f = "" + for a in ali: + if a[0] == ERR: + hyp_f += a[1] + elif a[1] == ERR: + pass + else: + hyp_f += a[0] + hyps_fusion.append(hyp_f) + return hyps_fusion + + +def main(): + parser = get_parser() + args = parser.parse_args() + # mkdir output_log_dir + Path(args.output_log_dir).mkdir(parents=True, exist_ok=True) + pair_logs = 
get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir) + for pair in pair_logs: + hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0]) + hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1]) + + hyps_fusion = fusion_hyps_trust_substituion_insertion( + hyps_whisper, hyps_zipformer, refs + ) + + partition_name = pair[0].split("/")[-1].split("-")[1] + save_results( + Path(args.output_log_dir), + partition_name, + {"fusion": list(zip(wav_name, refs, hyps_fusion))}, + ) + + print(f"Processed {partition_name}") + + +if __name__ == "__main__": + main() diff --git a/egs/speechio/ASR/prepare.sh b/egs/speechio/ASR/prepare.sh index 5b29440e5..048a66d8f 100644 --- a/egs/speechio/ASR/prepare.sh +++ b/egs/speechio/ASR/prepare.sh @@ -12,7 +12,7 @@ stop_stage=3 # - $dl_dir/SPEECHIO_ASR_ZH00000 # This directory contains the following files downloaded from # https://github.com/SpeechColab/Leaderboard -# +# # - metadata.tsv # - wav # - wav.scp diff --git a/egs/speechio/ASR/whisper/asr_datamodule.py b/egs/speechio/ASR/whisper/asr_datamodule.py index a32ea83e8..7382fd3f5 100644 --- a/egs/speechio/ASR/whisper/asr_datamodule.py +++ b/egs/speechio/ASR/whisper/asr_datamodule.py @@ -34,9 +34,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures SimpleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, -) +from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py index 5a0d478eb..001367791 100644 --- a/egs/speechio/ASR/whisper/decode.py +++ b/egs/speechio/ASR/whisper/decode.py @@ -53,14 +53,14 @@ import k2 import torch import torch.nn as nn import whisper - from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset from tn.chinese.normalizer import Normalizer from whisper.normalizers import BasicTextNormalizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from zhconv import convert -from lhotse.cut import Cut -from multi_dataset import MultiDataset + from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.env import get_env_info from icefall.utils import ( diff --git a/egs/speechio/ASR/whisper/multi_dataset.py b/egs/speechio/ASR/whisper/multi_dataset.py index f427c271f..f55d45394 100644 --- a/egs/speechio/ASR/whisper/multi_dataset.py +++ b/egs/speechio/ASR/whisper/multi_dataset.py @@ -45,17 +45,15 @@ class MultiDataset: idx = f"{i}".zfill(2) dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") - prefix="speechio" - suffix="jsonl.gz" + prefix = "speechio" + suffix = "jsonl.gz" results_dict = {} for partition in dataset_parts: path = f"{prefix}_cuts_{partition}.{suffix}" logging.info(f"Loading {path} set in lazy mode") - test_cuts = load_manifest_lazy( - self.fbank_dir / path - ) + test_cuts = load_manifest_lazy(self.fbank_dir / path) results_dict[partition] = test_cuts - return results_dict \ No newline at end of file + return results_dict diff --git a/egs/speechio/ASR/zipformer/decode.py b/egs/speechio/ASR/zipformer/decode.py index 91c43d044..ffdd7b500 100644 --- a/egs/speechio/ASR/zipformer/decode.py +++ b/egs/speechio/ASR/zipformer/decode.py @@ -303,6 +303,17 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--blank-penalty", + 
type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) add_model_arguments(parser) return parser @@ -431,6 +442,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -455,6 +467,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, ) elif params.decoding_method == "beam_search": hyp = beam_search( @@ -468,8 +481,9 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search_" + key: hyps} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -657,6 +671,7 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" if params.use_averaged_model: params.suffix += "-use-averaged-model" diff --git a/egs/speechio/ASR/zipformer/train.py b/egs/speechio/ASR/zipformer/train.py deleted file mode 100644 index c1bbd2ee8..000000000 --- a/egs/speechio/ASR/zipformer/train.py +++ /dev/null @@ -1,1385 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import AsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from multi_dataset import MultiDataset -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. 
If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). 
We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. 
- - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - 
vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. 
It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. 
- - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
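# (For reference on the comment above: GradScaler's built-in growth policy is
#  fixed at construction time, e.g.
#  torch.cuda.amp.GradScaler(init_scale=1.0, growth_factor=2.0, growth_interval=2000),
#  so a schedule that depends on the current scale, as below, has to be done by hand.)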
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - data_module = AsrDataModule(args) - multi_dataset = MultiDataset(args.manifest_dir) - - train_cuts = multi_dataset.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. 
" - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = data_module.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = multi_dataset.dev_cuts() - valid_dl = data_module.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/speechio/ASR/zipformer/train.py b/egs/speechio/ASR/zipformer/train.py new file mode 120000 index 000000000..ad7216cf7 --- /dev/null +++ b/egs/speechio/ASR/zipformer/train.py @@ -0,0 +1 @@ +../../../multi_zh-hans/ASR/zipformer/train.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index d73644d71..ac4e92ec5 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -16,11 +16,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging from pathlib import Path -import argparse + import torch -from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
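For orientation, the feature-extraction hunks below in this script boil down to one
shared selection pattern; a minimal sketch using the names these scripts already define:

    if args.whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
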
@@ -32,6 +40,7 @@ torch.multiprocessing.set_sharing_strategy("file_system") from icefall.utils import str2bool + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -52,6 +61,7 @@ def get_parser(): ) return parser + def compute_fbank_wenetspeech_dev_test(args): in_out_dir = Path("data/fbank") # number of workers in dataloader @@ -66,7 +76,9 @@ def compute_fbank_wenetspeech_dev_test(args): if torch.cuda.is_available(): device = torch.device("cuda", 0) if args.whisper_fbank: - extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda") + ) else: extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index ad7eb55cc..804a302bd 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -22,20 +22,19 @@ from datetime import datetime from pathlib import Path import torch -from lhotse import ( +from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig, CutSet, - WhisperFbank, - WhisperFbankConfig, - # KaldifeatWhisperFbank, - # KaldifeatWhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, set_audio_duration_mismatch_tolerance, set_caching_enabled, ) -from icefall.utils import str2bool, get_executor +from icefall.utils import get_executor, str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -148,11 +147,11 @@ def compute_fbank_wenetspeech_splits(args): set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance set_caching_enabled(False) - #with get_executor() as ex: # Initialize the executor only once. + # with get_executor() as ex: # Initialize the executor only once. for i in range(start, stop): idx = f"{i}".zfill(num_digits) logging.info(f"Processing {i+1}/{num_splits}") - + cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") @@ -177,13 +176,6 @@ def compute_fbank_wenetspeech_splits(args): storage_type=LilcomChunkyWriter, overwrite=True, ) - # cut_set = cut_set.compute_and_store_features( - # extractor=extractor, - # storage_path=f"{output_dir}/feats_{subset}_{idx}", - # num_jobs=args.num_workers, - # executor=ex, - # storage_type=LilcomChunkyWriter, - # ) logging.info(f"Saving to {cuts_path}") cut_set.to_file(cuts_path) diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index fc2f9c16b..a8ad36e52 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail nj=15 -stage=131 -stop_stage=131 +stage=0 +stop_stage=100 # Split L subset to this number of pieces # This is to avoid OOM during feature extraction. @@ -309,7 +309,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then mkdir -p $text_out_dir log "Genearating training text data" - + if [ ! 
-f $text_out_dir/lm_data.pt ]; then ./local/prepare_char_lm_training_data.py \ --lang-char data/lang_char \ @@ -318,14 +318,14 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then fi log "Generating DEV text data" - # prepare validation text data + # prepare validation text data if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then valid_text=${text_out_dir}/ gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ | jq '.text' | sed 's/"//g' \ | ./local/text2token.py -t "char" > $text_out_dir/valid_text - + python3 ./local/text2segments.py \ --num-process $nj \ --input-file $text_out_dir/valid_text \ @@ -337,7 +337,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then --lm-data $text_out_dir/valid_text_words_segmentation \ --lm-archive $text_out_dir/lm_data_valid.pt - # prepare TEST text data + # prepare TEST text data if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then log "Prepare text for test set." for test_set in TEST_MEETING TEST_NET; do @@ -350,7 +350,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then --input-file $text_out_dir/${test_set}_text \ --output-file $text_out_dir/${test_set}_text_words_segmentation done - + cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation fi @@ -401,4 +401,4 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ --vocab-size 5537 \ --master-port 12340 -fi \ No newline at end of file +fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py index ee8252a90..8c192913e 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_check.py @@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging -from icefall import is_module_available +import torch from onnx_pretrained import OnnxModel -import torch +from icefall import is_module_available def get_parser(): diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py index 1b174ea6d..103f8d725 100755 --- a/egs/wenetspeech/ASR/whisper/decode.py +++ b/egs/wenetspeech/ASR/whisper/decode.py @@ -52,13 +52,13 @@ import k2 import torch import torch.nn as nn import whisper - from asr_datamodule import WenetSpeechAsrDataModule +from lhotse.cut import Cut from tn.chinese.normalizer import Normalizer from whisper.normalizers import BasicTextNormalizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from zhconv import convert -from lhotse.cut import Cut + from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.env import get_env_info from icefall.utils import ( diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 2316a3c3b..b1e0253e6 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -834,6 +834,7 @@ def run(rank, world_size, args): # ) return False return True + train_cuts = wenetspeech.train_cuts() train_cuts = train_cuts.filter(remove_short_and_long_utt) train_dl = wenetspeech.train_dataloaders(train_cuts)
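
For context, the utterance filter touched in this last hunk follows the usual
lhotse pattern; a minimal standalone sketch (the manifest path is assumed, and the
1 s / 20 s thresholds are the ones used by the zipformer recipe deleted above --
the whisper script may use different limits):

    from lhotse import CutSet

    def remove_short_and_long_utt(c):
        # keep utterances between 1 and 20 seconds
        return 1.0 <= c.duration <= 20.0

    cuts = CutSet.from_file("data/fbank/cuts_L.jsonl.gz")
    cuts = cuts.filter(remove_short_and_long_utt)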