diff --git a/egs/multi_en/ASR/prepare.sh b/egs/multi_en/ASR/prepare.sh index 2e429fb3a..1f8235153 100755 --- a/egs/multi_en/ASR/prepare.sh +++ b/egs/multi_en/ASR/prepare.sh @@ -29,7 +29,7 @@ vocab_sizes=( multidataset=( "gigaspeech", "commonvoice", - "peoples_speech", + "librilight", ) # All files generated by this script are saved in "data". @@ -164,18 +164,18 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then cd ../.. fi - # People's Speech - if [[ "${multidataset[@]}" =~ "peoples_speech" ]] && [ ! -f data/fbank/.peoples_speech.done ]; then - log "Dataset: People's Speech" + # LibriLight + if [[ "${multidataset[@]}" =~ "librilight" ]] && [ ! -f data/fbank/.librilight.done ]; then + log "Dataset: LibriLight" cd data/fbank - if [ -f ../../../../peoples_speech/ASR/data/fbank/.peoples_speech_train.done ]; then - ln -svf $(realpath ../../../../peoples_speech/ASR/data/fbank/peoples_speech_train_split) . + if [ -f ../../../../librilight/ASR/data/fbank/.librilight_train.done ]; then + ln -svf $(realpath ../../../../librilight/ASR/data/fbank/librilight_train_split) . else - log "Abort! Please run ../../peoples_speech/ASR/prepare.sh --stage 5 --stop-stage 6" + log "Abort! Please run ../../librilight/ASR/prepare.sh --stage 5 --stop-stage 6" exit 1 fi - touch .peoples_speech.done + touch .librilight.done cd ../.. fi fi diff --git a/egs/multi_en/ASR/zipformer/ctc_decode.py b/egs/multi_en/ASR/zipformer/ctc_decode.py index 4db50b981..1f0f9bfac 100755 --- a/egs/multi_en/ASR/zipformer/ctc_decode.py +++ b/egs/multi_en/ASR/zipformer/ctc_decode.py @@ -88,7 +88,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/multi_en/ASR/zipformer/decode.py b/egs/multi_en/ASR/zipformer/decode.py index 352d82f4b..832d541b4 100755 --- a/egs/multi_en/ASR/zipformer/decode.py +++ b/egs/multi_en/ASR/zipformer/decode.py @@ -116,7 +116,8 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_model +from multidataset import MultiDataset +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -782,6 +783,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) + multidataset = MultiDataset(args.manifest_dir) test_clean_cuts = librispeech.test_clean_cuts() test_other_cuts = librispeech.test_other_cuts() @@ -789,8 +791,30 @@ def main(): test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_cuts = multidataset.test_cuts() + + gigaspeech_dev_dl = librispeech.test_dataloaders(test_cuts[0]) + gigaspeech_test_dl = librispeech.test_dataloaders(test_cuts[1]) + commonvoice_dev_dl = librispeech.test_dataloaders(test_cuts[2]) + commonvoice_test_dl = librispeech.test_dataloaders(test_cuts[3]) + + test_sets = [ + "librispeech-test-clean", + "librispeech-test-other", + "gigaspeech-dev", + "gigaspeech-test", + "commonvoice-dev", + "commonvoice-test", + ] + + test_dl = [ + test_clean_dl, + test_other_dl, + gigaspeech_dev_dl, + gigaspeech_test_dl, + commonvoice_dev_dl, + commonvoice_test_dl, + ] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/multi_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_en/ASR/zipformer/export-onnx-streaming.py index 8cec09869..a2c97ccfa 100755 --- a/egs/multi_en/ASR/zipformer/export-onnx-streaming.py +++ b/egs/multi_en/ASR/zipformer/export-onnx-streaming.py @@ -76,7 +76,7 @@ import torch.nn as nn from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from zipformer import Zipformer2 from icefall.checkpoint import ( @@ -85,7 +85,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool, make_pad_mask +from icefall.utils import make_pad_mask, str2bool def get_parser(): @@ -182,7 +182,10 @@ class OnnxEncoder(nn.Module): """A wrapper for Zipformer and the encoder_proj from the joiner""" def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + encoder_proj: nn.Linear, ): """ Args: @@ -210,7 +213,11 @@ class OnnxEncoder(nn.Module): left_context_len = self.left_context_len cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = self.encoder_embed.streaming_forward( x=x, x_lens=x_lens, cached_left_pad=cached_embed_left_pad, diff --git a/egs/multi_en/ASR/zipformer/export-onnx.py b/egs/multi_en/ASR/zipformer/export-onnx.py index f5b01ce71..4db560127 100755 --- a/egs/multi_en/ASR/zipformer/export-onnx.py +++ b/egs/multi_en/ASR/zipformer/export-onnx.py @@ -74,7 +74,7 @@ import torch.nn as nn from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from zipformer import Zipformer2 from icefall.checkpoint import ( @@ -83,7 +83,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool, make_pad_mask +from icefall.utils import make_pad_mask, str2bool def get_parser(): @@ -180,7 +180,10 @@ class OnnxEncoder(nn.Module): """A wrapper for Zipformer and the encoder_proj from the joiner""" def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + encoder_proj: nn.Linear, ): """ Args: diff --git a/egs/multi_en/ASR/zipformer/export.py b/egs/multi_en/ASR/zipformer/export.py index f9036f443..8f7ba5e8c 100755 --- a/egs/multi_en/ASR/zipformer/export.py +++ b/egs/multi_en/ASR/zipformer/export.py @@ -160,8 +160,9 @@ from typing import List, Tuple import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from torch import Tensor, nn -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -170,7 +171,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.utils import make_pad_mask, str2bool -from scaling_converter import convert_scaled_to_non_scaled def get_parser(): @@ -315,7 +315,11 @@ class StreamingEncoderModel(nn.Module): left_context_len = self.left_context_len cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = self.encoder_embed.streaming_forward( x=features, x_lens=feature_lengths, cached_left_pad=cached_embed_left_pad, diff --git a/egs/multi_en/ASR/zipformer/multidataset.py b/egs/multi_en/ASR/zipformer/multidataset.py index 798aa27ba..f3f7b883d 100644 --- a/egs/multi_en/ASR/zipformer/multidataset.py +++ b/egs/multi_en/ASR/zipformer/multidataset.py @@ -71,30 +71,57 @@ class MultiDataset: self.manifest_dir / f"cv-en_cuts_train.jsonl.gz" ) - # People's Speech - sorted_filenames = sorted( - glob.glob( - f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*[yna].*.jsonl.gz" - ) + # LibriHeavy + logging.info("Loading LibriHeavy in lazy mode") + libriheavy_small_cuts = load_manifest_lazy( + self.manifest_dir / "libriheavy_cuts_train_small.jsonl.gz" ) - - logging.info( - f"Loading People's Speech {len(sorted_filenames)} splits in lazy mode" - ) - - peoples_speech_cuts = lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames + libriheavy_medium_cuts = load_manifest_lazy( + self.manifest_dir / "libriheavy_cuts_train_medium.jsonl.gz" ) + libriheavy_cuts = lhotse.combine(libriheavy_small_cuts, libriheavy_medium_cuts) return CutSet.mux( librispeech_cuts, gigaspeech_cuts, commonvoice_cuts, - peoples_speech_cuts, + libriheavy_cuts, weights=[ len(librispeech_cuts), len(gigaspeech_cuts), len(commonvoice_cuts), - len(peoples_speech_cuts), + len(libriheavy_cuts), ], ) + + def test_cuts(self) -> CutSet: + logging.info("About to get multidataset test cuts") + + # GigaSpeech + logging.info("Loading GigaSpeech DEV in lazy mode") + gigaspeech_dev_cuts = load_manifest_lazy( + self.manifest_dir / "cuts_DEV.jsonl.gz" + ) + + logging.info("Loading GigaSpeech TEST in lazy mode") + gigaspeech_test_cuts = load_manifest_lazy( + self.manifest_dir / "cuts_TEST.jsonl.gz" + ) + + # CommonVoice + logging.info("Loading CommonVoice DEV in lazy mode") + commonvoice_dev_cuts = load_manifest_lazy( + self.manifest_dir / "cv-en_cuts_dev.jsonl.gz" + ) + + logging.info("Loading CommonVoice TEST in lazy mode") + commonvoice_test_cuts = load_manifest_lazy( + self.manifest_dir / "cv-en_cuts_test.jsonl.gz" + ) + + return [ + gigaspeech_dev_cuts, + gigaspeech_test_cuts, + commonvoice_dev_cuts, + commonvoice_test_cuts, + ] diff --git a/egs/multi_en/ASR/zipformer/streaming_decode.py b/egs/multi_en/ASR/zipformer/streaming_decode.py index c079ff5a7..30348d7e6 100755 --- a/egs/multi_en/ASR/zipformer/streaming_decode.py +++ b/egs/multi_en/ASR/zipformer/streaming_decode.py @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -374,7 +374,11 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, diff --git a/egs/multi_en/ASR/zipformer/train.py b/egs/multi_en/ASR/zipformer/train.py index a42613a27..a32a53db3 100755 --- a/egs/multi_en/ASR/zipformer/train.py +++ b/egs/multi_en/ASR/zipformer/train.py @@ -66,13 +66,13 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from multidataset import MultiDataset from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel +from multidataset import MultiDataset from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling @@ -344,7 +344,7 @@ def get_parser(): parser.add_argument( "--lr-hours", type=float, - default=5000, + default=70000, help="""Number of hours that affects how rapidly the learning rate decreases. """, ) @@ -1052,7 +1052,9 @@ def train_one_epoch( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: @@ -1387,5 +1389,6 @@ def main(): torch.set_num_threads(1) torch.set_num_interop_threads(1) + if __name__ == "__main__": main()