diff --git a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py index be18e65c1..98d246985 100644 --- a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py +++ b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py @@ -39,7 +39,7 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool -class ReazonSpeechAsrDataModule: +class MultiDatasetAsrDataModule: """ DataModule for k2 ASR experiments. It assumes there is always one train and valid dataloader, @@ -333,23 +333,23 @@ class ReazonSpeechAsrDataModule: ) return test_dl - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" - ) + # @lru_cache() + # def train_cuts(self) -> CutSet: + # logging.info("About to get train cuts") + # return load_manifest_lazy( + # self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" + # ) - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" - ) + # @lru_cache() + # def valid_cuts(self) -> CutSet: + # logging.info("About to get dev cuts") + # return load_manifest_lazy( + # self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" + # ) - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" - ) + # @lru_cache() + # def test_cuts(self) -> List[CutSet]: + # logging.info("About to get test cuts") + # return load_manifest_lazy( + # self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" + # ) diff --git a/egs/multi_ja_en/ASR/zipformer/decode.py b/egs/multi_ja_en/ASR/zipformer/decode.py index 26ce3e018..37cf39ddd 100755 --- a/egs/multi_ja_en/ASR/zipformer/decode.py +++ b/egs/multi_ja_en/ASR/zipformer/decode.py @@ -68,7 +68,7 @@ import k2 import 
sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule +from asr_datamodule import MultiDatasetAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -573,7 +573,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) + MultiDatasetAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -748,7 +748,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - data_module = ReazonSpeechAsrDataModule(args) + data_module = MultiDatasetAsrDataModule(args) multi_dataset = MultiDataset(args) def remove_short_utt(c: Cut): diff --git a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py index 072679cfc..32e6380eb 100755 --- a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py +++ b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py @@ -57,7 +57,7 @@ import optim import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule +from asr_datamodule import MultiDatasetAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -1085,8 +1085,8 @@ def run(rank, world_size, args): return True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - train_cuts = reazonspeech_corpus.train_cuts() + multidataset_datamodule = MultiDatasetAsrDataModule(args) + train_cuts = multidataset_datamodule.train_cuts() train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1097,12 +1097,12 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = reazonspeech_corpus.train_dataloaders( + train_dl = multidataset_datamodule.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = reazonspeech_corpus.valid_cuts() - valid_dl = 
reazonspeech_corpus.valid_dataloaders(valid_cuts) + valid_cuts = multidataset_datamodule.valid_cuts() + valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts) if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( @@ -1242,7 +1242,7 @@ def scan_pessimistic_batches_for_oom( def main(): raise RuntimeError("Please don't use this file directly!") parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) + MultiDatasetAsrDataModule.add_arguments(parser) Tokenizer.add_arguments(parser) args = parser.parse_args() diff --git a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py index b0cdc1f6a..171dccf5b 100644 --- a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py +++ b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py @@ -13,36 +13,36 @@ class MultiDataset: Args: manifest_dir: It is expected to contain the following files: - - reazonspeech_cuts_train.jsonl.gz - - librispeech_cuts_train-clean-100.jsonl.gz - - librispeech_cuts_train-clean-360.jsonl.gz - - librispeech_cuts_train-other-500.jsonl.gz + - mls_english/ + - mls_eng_cuts_train.jsonl.gz + - mls_eng_cuts_dev.jsonl.gz + - mls_eng_cuts_test.jsonl.gz + - reazonspeech/ + - reazonspeech_cuts_train.jsonl.gz + - reazonspeech_cuts_dev.jsonl.gz + - reazonspeech_cuts_test.jsonl.gz """ - self.fbank_dir = Path(args.manifest_dir) + self.manifest_dir = Path(args.manifest_dir) def train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") - logging.info("Loading Reazonspeech in lazy mode") - reazonspeech_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz" + logging.info("Loading Reazonspeech TRAIN set in lazy mode") + reazonspeech_train_cuts = load_manifest_lazy( + self.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" ) - logging.info("Loading LibriSpeech in lazy mode") - train_clean_100_cuts = self.train_clean_100_cuts() - train_clean_360_cuts = self.train_clean_360_cuts() - 
train_other_500_cuts = self.train_other_500_cuts() + logging.info("Loading MLS English TRAIN set in lazy mode") + mls_eng_train_cuts = load_manifest_lazy( + self.manifest_dir / "mls_eng_cuts_train.jsonl.gz" + ) return CutSet.mux( - reazonspeech_cuts, - train_clean_100_cuts, - train_clean_360_cuts, - train_other_500_cuts, + reazonspeech_train_cuts, + mls_eng_train_cuts, weights=[ - len(reazonspeech_cuts), - len(train_clean_100_cuts), - len(train_clean_360_cuts), - len(train_other_500_cuts), + len(reazonspeech_train_cuts), + len(mls_eng_train_cuts), ], ) @@ -51,93 +51,90 @@ class MultiDataset: logging.info("Loading Reazonspeech DEV set in lazy mode") reazonspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz" + self.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" ) - logging.info("Loading LibriSpeech DEV set in lazy mode") - dev_clean_cuts = self.dev_clean_cuts() - dev_other_cuts = self.dev_other_cuts() + logging.info("Loading MLS English DEV set in lazy mode") + mls_eng_dev_cuts = load_manifest_lazy( + self.manifest_dir / "mls_eng_cuts_dev.jsonl.gz" + ) return CutSet.mux( reazonspeech_dev_cuts, - dev_clean_cuts, - dev_other_cuts, + mls_eng_dev_cuts, weights=[ len(reazonspeech_dev_cuts), - len(dev_clean_cuts), - len(dev_other_cuts), + len(mls_eng_dev_cuts), ], ) - def test_cuts(self) -> Dict[str, CutSet]: + def test_cuts(self) -> CutSet: logging.info("About to get multidataset test cuts") - logging.info("Loading Reazonspeech set in lazy mode") + logging.info("Loading Reazonspeech TEST set in lazy mode") reazonspeech_test_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz" - ) - reazonspeech_dev_cuts = load_manifest_lazy( - self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz" + self.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" ) - logging.info("Loading LibriSpeech set in lazy mode") - test_clean_cuts = self.test_clean_cuts() - test_other_cuts = self.test_other_cuts() - - test_cuts = { - 
"reazonspeech_test": reazonspeech_test_cuts, - "reazonspeech_dev": reazonspeech_dev_cuts, - "librispeech_test_clean": test_clean_cuts, - "librispeech_test_other": test_other_cuts, - } - - return test_cuts - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + logging.info("Loading MLS English TEST set in lazy mode") + mls_eng_test_cuts = load_manifest_lazy( + self.manifest_dir / "mls_eng_cuts_test.jsonl.gz" ) - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + return CutSet.mux( + reazonspeech_test_cuts, + mls_eng_test_cuts, + weights=[ + len(reazonspeech_test_cuts), + len(mls_eng_test_cuts), + ], ) - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) + # @lru_cache() + # def train_clean_100_cuts(self) -> CutSet: + # logging.info("About to get train-clean-100 cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + # ) - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) + # @lru_cache() + # def train_clean_360_cuts(self) -> CutSet: + # logging.info("About to get train-clean-360 cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + # ) - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) + # @lru_cache() + # def train_other_500_cuts(self) -> 
CutSet: + # logging.info("About to get train-other-500 cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + # ) - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) + # @lru_cache() + # def dev_clean_cuts(self) -> CutSet: + # logging.info("About to get dev-clean cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + # ) - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" - ) + # @lru_cache() + # def dev_other_cuts(self) -> CutSet: + # logging.info("About to get dev-other cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + # ) + + # @lru_cache() + # def test_clean_cuts(self) -> CutSet: + # logging.info("About to get test-clean cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + # ) + + # @lru_cache() + # def test_other_cuts(self) -> CutSet: + # logging.info("About to get test-other cuts") + # return load_manifest_lazy( + # self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + # ) diff --git a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py index 935f86de1..e1869d784 100755 --- a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py +++ b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py @@ -63,7 +63,7 @@ import k2 import numpy as np import sentencepiece as spm import torch -from asr_datamodule import ReazonSpeechAsrDataModule +from asr_datamodule import MultiDatasetAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -740,7 +740,7 @@ def save_results( @torch.no_grad() def 
main(): parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) + MultiDatasetAsrDataModule.add_arguments(parser) Tokenizer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -887,7 +887,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + multidataset_datamodule = MultiDatasetAsrDataModule(args) if params.bilingual: multi_dataset = MultiDataset(args) @@ -904,8 +904,8 @@ def main(): test_sets = test_sets_cuts.keys() test_cuts = [test_sets_cuts[k] for k in test_sets] - valid_cuts = reazonspeech_corpus.valid_cuts() - test_cuts = reazonspeech_corpus.test_cuts() + valid_cuts = multidataset_datamodule.valid_cuts() + test_cuts = multidataset_datamodule.test_cuts() test_sets = ["valid", "test"] test_cuts = [valid_cuts, test_cuts] diff --git a/egs/multi_ja_en/ASR/zipformer/train.py b/egs/multi_ja_en/ASR/zipformer/train.py index bfb037f50..c3664f7f5 100755 --- a/egs/multi_ja_en/ASR/zipformer/train.py +++ b/egs/multi_ja_en/ASR/zipformer/train.py @@ -66,7 +66,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule +from asr_datamodule import MultiDatasetAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -272,7 +272,7 @@ def get_parser(): parser.add_argument( "--bilingual", type=str2bool, - default=False, + default=True, help="Whether the model is bilingual or not. 
1 = bilingual.", )  @@ -804,7 +804,8 @@ def compute_loss( texts = batch["supervisions"]["text"] if not params.bilingual: - y = tokenizer.encode(texts, out_type=int) + raise NotImplementedError("only bilingual training has been implemented") + # y = tokenizer.encode(texts, out_type=int) else: y = sentencepiece_processor.encode(texts, out_type=int) y = k2.RaggedTensor(y) @@ -1147,9 +1148,10 @@ def run(rank, world_size, args): # is defined in local/prepare_lang_char.py if not params.bilingual: - tokenizer = Tokenizer.load(args.lang, args.lang_type) - params.blank_id = tokenizer.piece_to_id("") - params.vocab_size = tokenizer.get_piece_size() + raise NotImplementedError("only bilingual training has been implemented") + # tokenizer = Tokenizer.load(args.lang, args.lang_type) + # params.blank_id = tokenizer.piece_to_id("") + # params.vocab_size = tokenizer.get_piece_size() else: sentencepiece_processor = spm.SentencePieceProcessor() sentencepiece_processor.load(params.bpe_model) @@ -1212,12 +1214,13 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + multidataset_datamodule = MultiDatasetAsrDataModule(args) if params.bilingual: multi_dataset = MultiDataset(args) train_cuts = multi_dataset.train_cuts() else: - train_cuts = reazonspeech_corpus.train_cuts() + raise NotImplementedError("only bilingual training has been implemented") + # train_cuts = reazonspeech_corpus.train_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1242,6 +1245,7 @@ def run(rank, world_size, args): # for subsampling T = ((c.num_samples - 7) // 2 + 1) // 2 if not params.bilingual: + raise NotImplementedError("only bilingual training has been implemented") tokens = tokenizer.encode(c.supervisions[0].text, out_type=str) else: tokens = sentencepiece_processor.encode( @@ -1272,6 +1276,8 @@ def run(rank, world_size, args): if 
params.bilingual: train_cuts = train_cuts.map(tokenize_and_encode_text) + else: + raise NotImplementedError("only bilingual training has been implemented") if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None - train_dl = reazonspeech_corpus.train_dataloaders( + # train_dl = reazonspeech_corpus.train_dataloaders( + # train_cuts, sampler_state_dict=sampler_state_dict + # ) + train_dl = multidataset_datamodule.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) if params.bilingual: - valid_cuts = reazonspeech_corpus.valid_cuts() - else: valid_cuts = multi_dataset.dev_cuts() - valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) + else: + raise NotImplementedError("only bilingual training has been implemented") + # valid_cuts = multi_dataset.dev_cuts() + + valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts) if not params.print_diagnostics: scan_pessimistic_batches_for_oom( @@ -1386,7 +1397,8 @@ def display_and_save_batch( if params.bilingual: y = sentencepiece_processor.encode(supervisions["text"], out_type=int) else: - y = tokenizer.encode(supervisions["text"], out_type=int) + raise NotImplementedError("only bilingual training has been implemented") + # y = tokenizer.encode(supervisions["text"], out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}") @@ -1442,7 +1454,7 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) + MultiDatasetAsrDataModule.add_arguments(parser) Tokenizer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir)