deprecate params.bilingual=0, replace ReazonSpeechAsrDataModule with MultiDatasetAsrDataModule, not tested yet

Kinan Martin 2025-05-14 08:40:15 +09:00
parent b2df5bbb83
commit eb5004880f
6 changed files with 144 additions and 135 deletions

View File

@@ -39,7 +39,7 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool

-class ReazonSpeechAsrDataModule:
+class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -333,23 +333,23 @@ class ReazonSpeechAsrDataModule:
         )
         return test_dl

-    @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
-        )
+    # @lru_cache()
+    # def train_cuts(self) -> CutSet:
+    #     logging.info("About to get train cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
+    #     )

-    @lru_cache()
-    def valid_cuts(self) -> CutSet:
-        logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
-        )
+    # @lru_cache()
+    # def valid_cuts(self) -> CutSet:
+    #     logging.info("About to get dev cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
+    #     )

-    @lru_cache()
-    def test_cuts(self) -> List[CutSet]:
-        logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_cuts(self) -> List[CutSet]:
+    #     logging.info("About to get test cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
+    #     )
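Note: after this change the data module only builds samplers and dataloaders; the cuts themselves come from the MultiDataset helper edited further down. A minimal sketch of the intended wiring, assuming the usual icefall recipe layout (MultiDataset importable from multi_dataset.py, a --manifest-dir option registered by add_arguments; the directory below is a placeholder):

import argparse

from asr_datamodule import MultiDatasetAsrDataModule
from multi_dataset import MultiDataset  # assumed module name

parser = argparse.ArgumentParser()
MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank"])  # placeholder directory

data_module = MultiDatasetAsrDataModule(args)  # builds samplers / dataloaders only
multi_dataset = MultiDataset(args)             # lazily loads and muxes the manifests

train_dl = data_module.train_dataloaders(multi_dataset.train_cuts())
valid_dl = data_module.valid_dataloaders(multi_dataset.dev_cuts())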

View File

@@ -68,7 +68,7 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -573,7 +573,7 @@ def save_results(

 @torch.no_grad()
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

@@ -748,7 +748,7 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True

-    data_module = ReazonSpeechAsrDataModule(args)
+    data_module = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)

     def remove_short_utt(c: Cut):

View File

@@ -57,7 +57,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1085,8 +1085,8 @@ def run(rank, world_size, args):
         return True

-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-    train_cuts = reazonspeech_corpus.train_cuts()
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
+    train_cuts = multidataset_datamodule.train_cuts()

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1097,12 +1097,12 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = reazonspeech_corpus.train_dataloaders(
+    train_dl = multidataset_datamodule.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = reazonspeech_corpus.valid_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+    valid_cuts = multidataset_datamodule.valid_cuts()
+    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)

     if params.start_batch <= 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1242,7 +1242,7 @@ def scan_pessimistic_batches_for_oom(

 def main():
     raise RuntimeError("Please don't use this file directly!")
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     Tokenizer.add_arguments(parser)
     args = parser.parse_args()

View File

@@ -13,36 +13,36 @@ class MultiDataset:
         Args:
           manifest_dir:
             It is expected to contain the following files:
-            - reazonspeech_cuts_train.jsonl.gz
-            - librispeech_cuts_train-clean-100.jsonl.gz
-            - librispeech_cuts_train-clean-360.jsonl.gz
-            - librispeech_cuts_train-other-500.jsonl.gz
+            - mls_english/
+              - mls_eng_cuts_train.jsonl.gz
+              - mls_eng_cuts_dev.jsonl.gz
+              - mls_eng_cuts_test.jsonl.gz
+            - reazonspeech/
+              - reazonspeech_cuts_train.jsonl.gz
+              - reazonspeech_cuts_dev.jsonl.gz
+              - reazonspeech_cuts_test.jsonl.gz
         """
-        self.fbank_dir = Path(args.manifest_dir)
+        self.manifest_dir = Path(args.manifest_dir)

     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")

-        logging.info("Loading Reazonspeech in lazy mode")
-        reazonspeech_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz"
+        logging.info("Loading Reazonspeech TRAIN set in lazy mode")
+        reazonspeech_train_cuts = load_manifest_lazy(
+            self.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
         )

-        logging.info("Loading LibriSpeech in lazy mode")
-        train_clean_100_cuts = self.train_clean_100_cuts()
-        train_clean_360_cuts = self.train_clean_360_cuts()
-        train_other_500_cuts = self.train_other_500_cuts()
+        logging.info("Loading MLS English TRAIN set in lazy mode")
+        mls_eng_train_cuts = load_manifest_lazy(
+            self.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
+        )

         return CutSet.mux(
-            reazonspeech_cuts,
-            train_clean_100_cuts,
-            train_clean_360_cuts,
-            train_other_500_cuts,
+            reazonspeech_train_cuts,
+            mls_eng_train_cuts,
             weights=[
-                len(reazonspeech_cuts),
-                len(train_clean_100_cuts),
-                len(train_clean_360_cuts),
-                len(train_other_500_cuts),
+                len(reazonspeech_train_cuts),
+                len(mls_eng_train_cuts),
             ],
         )
@@ -51,93 +51,90 @@ class MultiDataset:
         logging.info("Loading Reazonspeech DEV set in lazy mode")
         reazonspeech_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
+            self.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
         )

-        logging.info("Loading LibriSpeech DEV set in lazy mode")
-        dev_clean_cuts = self.dev_clean_cuts()
-        dev_other_cuts = self.dev_other_cuts()
+        logging.info("Loading MLS English DEV set in lazy mode")
+        mls_eng_dev_cuts = load_manifest_lazy(
+            self.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
+        )

         return CutSet.mux(
             reazonspeech_dev_cuts,
-            dev_clean_cuts,
-            dev_other_cuts,
+            mls_eng_dev_cuts,
             weights=[
                 len(reazonspeech_dev_cuts),
-                len(dev_clean_cuts),
-                len(dev_other_cuts),
+                len(mls_eng_dev_cuts),
             ],
         )

-    def test_cuts(self) -> Dict[str, CutSet]:
+    def test_cuts(self) -> CutSet:
         logging.info("About to get multidataset test cuts")

-        logging.info("Loading Reazonspeech set in lazy mode")
+        logging.info("Loading Reazonspeech TEST set in lazy mode")
         reazonspeech_test_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz"
-        )
-        reazonspeech_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
+            self.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
         )

-        logging.info("Loading LibriSpeech set in lazy mode")
-        test_clean_cuts = self.test_clean_cuts()
-        test_other_cuts = self.test_other_cuts()
-
-        test_cuts = {
-            "reazonspeech_test": reazonspeech_test_cuts,
-            "reazonspeech_dev": reazonspeech_dev_cuts,
-            "librispeech_test_clean": test_clean_cuts,
-            "librispeech_test_other": test_other_cuts,
-        }
-
-        return test_cuts
-
-    @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
-        )
-
-    @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
-        )
-
-    @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_clean_cuts(self) -> CutSet:
-        logging.info("About to get dev-clean cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_other_cuts(self) -> CutSet:
-        logging.info("About to get dev-other cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_clean_cuts(self) -> CutSet:
-        logging.info("About to get test-clean cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
-        )
-
-    @lru_cache()
-    def test_other_cuts(self) -> CutSet:
-        logging.info("About to get test-other cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
-        )
+        logging.info("Loading MLS English TEST set in lazy mode")
+        mls_eng_test_cuts = load_manifest_lazy(
+            self.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
+        )
+
+        return CutSet.mux(
+            reazonspeech_test_cuts,
+            mls_eng_test_cuts,
+            weights=[
+                len(reazonspeech_test_cuts),
+                len(mls_eng_test_cuts),
+            ],
+        )
+
+    # @lru_cache()
+    # def train_clean_100_cuts(self) -> CutSet:
+    #     logging.info("About to get train-clean-100 cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+    #     )
+
+    # @lru_cache()
+    # def train_clean_360_cuts(self) -> CutSet:
+    #     logging.info("About to get train-clean-360 cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+    #     )
+
+    # @lru_cache()
+    # def train_other_500_cuts(self) -> CutSet:
+    #     logging.info("About to get train-other-500 cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+    #     )
+
+    # @lru_cache()
+    # def dev_clean_cuts(self) -> CutSet:
+    #     logging.info("About to get dev-clean cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+    #     )
+
+    # @lru_cache()
+    # def dev_other_cuts(self) -> CutSet:
+    #     logging.info("About to get dev-other cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+    #     )
+
+    # @lru_cache()
+    # def test_clean_cuts(self) -> CutSet:
+    #     logging.info("About to get test-clean cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+    #     )
+
+    # @lru_cache()
+    # def test_other_cuts(self) -> CutSet:
+    #     logging.info("About to get test-other cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+    #     )
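All three cut providers above follow the same pattern: load each corpus lazily, then mux the two CutSets with weights proportional to their sizes, so batches draw from ReazonSpeech and MLS English roughly in proportion to how much data each contributes. A standalone sketch of that pattern with lhotse (the manifest directory is a placeholder):

from pathlib import Path

from lhotse import CutSet, load_manifest_lazy

manifest_dir = Path("data/fbank")  # placeholder; the recipe passes args.manifest_dir

ja_cuts = load_manifest_lazy(manifest_dir / "reazonspeech_cuts_train.jsonl.gz")
en_cuts = load_manifest_lazy(manifest_dir / "mls_eng_cuts_train.jsonl.gz")

# CutSet.mux() lazily interleaves the two manifests; weighting by len() makes
# each corpus's sampling probability proportional to its number of cuts.
train_cuts = CutSet.mux(
    ja_cuts,
    en_cuts,
    weights=[len(ja_cuts), len(en_cuts)],
)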

View File

@@ -63,7 +63,7 @@ import k2
 import numpy as np
 import sentencepiece as spm
 import torch
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
@@ -740,7 +740,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -887,7 +887,7 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True

-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)

     if params.bilingual:
         multi_dataset = MultiDataset(args)

@@ -904,8 +904,8 @@ def main():
         test_sets = test_sets_cuts.keys()
         test_cuts = [test_sets_cuts[k] for k in test_sets]

-        valid_cuts = reazonspeech_corpus.valid_cuts()
-        test_cuts = reazonspeech_corpus.test_cuts()
+        valid_cuts = multidataset_datamodule.valid_cuts()
+        test_cuts = multidataset_datamodule.test_cuts()
         test_sets = ["valid", "test"]
         test_cuts = [valid_cuts, test_cuts]

View File

@@ -66,7 +66,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -272,7 +272,7 @@ def get_parser():
     parser.add_argument(
         "--bilingual",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether the model is bilingual or not. 1 = bilingual.",
     )
@@ -804,7 +804,8 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     if not params.bilingual:
-        y = tokenizer.encode(texts, out_type=int)
+        assert NotImplementedError("only bilingual training has been implemented")
+        # y = tokenizer.encode(texts, out_type=int)
     else:
         y = sentencepiece_processor.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
@@ -1147,9 +1148,10 @@ def run(rank, world_size, args):
     # <blk> is defined in local/prepare_lang_char.py
     if not params.bilingual:
-        tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        params.blank_id = tokenizer.piece_to_id("<blk>")
-        params.vocab_size = tokenizer.get_piece_size()
+        assert NotImplementedError("only bilingual training has been implemented")
+        # tokenizer = Tokenizer.load(args.lang, args.lang_type)
+        # params.blank_id = tokenizer.piece_to_id("<blk>")
+        # params.vocab_size = tokenizer.get_piece_size()
     else:
         sentencepiece_processor = spm.SentencePieceProcessor()
         sentencepiece_processor.load(params.bpe_model)
@@ -1212,12 +1214,13 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)

     if params.bilingual:
         multi_dataset = MultiDataset(args)
         train_cuts = multi_dataset.train_cuts()
     else:
-        train_cuts = reazonspeech_corpus.train_cuts()
+        assert NotImplementedError("only bilingual training has been implemented")
+        # train_cuts = reazonspeech_corpus.train_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1242,6 +1245,7 @@ def run(rank, world_size, args):
         # for subsampling
         T = ((c.num_samples - 7) // 2 + 1) // 2
         if not params.bilingual:
+            assert NotImplementedError("only bilingual training has been implemented")
             tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
         else:
             tokens = sentencepiece_processor.encode(
@@ -1272,6 +1276,8 @@ def run(rank, world_size, args):
     if params.bilingual:
         train_cuts = train_cuts.map(tokenize_and_encode_text)
+    else:
+        assert NotImplementedError("only bilingual training has been implemented")

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1280,15 +1286,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = reazonspeech_corpus.train_dataloaders(
+    # train_dl = reazonspeech_corpus.train_dataloaders(
+    #     train_cuts, sampler_state_dict=sampler_state_dict
+    # )
+    train_dl = multidataset_datamodule.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

     if params.bilingual:
-        valid_cuts = reazonspeech_corpus.valid_cuts()
-    else:
         valid_cuts = multi_dataset.dev_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+    else:
+        assert NotImplementedError("only bilingual training has been implemented")
+        # valid_cuts = multi_dataset.dev_cuts()
+    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1386,7 +1397,8 @@ def display_and_save_batch(
     if params.bilingual:
         y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
     else:
-        y = tokenizer.encode(supervisions["text"], out_type=int)
+        assert NotImplementedError("only bilingual training has been implemented")
+        # y = tokenizer.encode(supervisions["text"], out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
@@ -1442,7 +1454,7 @@ def scan_pessimistic_batches_for_oom(

 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)