deprecate params.bilingual=0; replace ReazonSpeechAsrDataModule with MultiDatasetAsrDataModule (not tested yet)

Kinan Martin 2025-05-14 08:40:15 +09:00
parent b2df5bbb83
commit eb5004880f
6 changed files with 144 additions and 135 deletions
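At a glance, the data-loading wiring this commit moves toward (a minimal, untested sketch of the bilingual path, assembled from the hunks below; the import path for MultiDataset and the surrounding args/sampler_state_dict plumbing are assumptions, not shown in this diff):

from asr_datamodule import MultiDatasetAsrDataModule
from multi_dataset import MultiDataset  # import path assumed; the class itself is edited below

# Corpus selection and mixing now live in MultiDataset; the renamed
# MultiDatasetAsrDataModule only builds dataloaders from the cuts it is given.
multidataset_datamodule = MultiDatasetAsrDataModule(args)
multi_dataset = MultiDataset(args)

train_cuts = multi_dataset.train_cuts()  # ReazonSpeech + MLS English, mixed via CutSet.mux
train_dl = multidataset_datamodule.train_dataloaders(
    train_cuts, sampler_state_dict=sampler_state_dict  # restored from a checkpoint, or None
)
valid_cuts = multi_dataset.dev_cuts()
valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)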

View File

@@ -39,7 +39,7 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
class ReazonSpeechAsrDataModule:
class MultiDatasetAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@@ -333,23 +333,23 @@ class ReazonSpeechAsrDataModule:
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
)
# @lru_cache()
# def train_cuts(self) -> CutSet:
# logging.info("About to get train cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
# )
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
)
# @lru_cache()
# def valid_cuts(self) -> CutSet:
# logging.info("About to get dev cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
# )
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
)
# @lru_cache()
# def test_cuts(self) -> List[CutSet]:
# logging.info("About to get test cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
# )

View File

@@ -68,7 +68,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -573,7 +573,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -748,7 +748,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
data_module = ReazonSpeechAsrDataModule(args)
data_module = MultiDatasetAsrDataModule(args)
multi_dataset = MultiDataset(args)
def remove_short_utt(c: Cut):

View File

@@ -57,7 +57,7 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -1085,8 +1085,8 @@ def run(rank, world_size, args):
return True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
train_cuts = reazonspeech_corpus.train_cuts()
multidataset_datamodule = MultiDatasetAsrDataModule(args)
train_cuts = multidataset_datamodule.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1097,12 +1097,12 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = reazonspeech_corpus.train_dataloaders(
train_dl = multidataset_datamodule.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = reazonspeech_corpus.valid_cuts()
valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
valid_cuts = multidataset_datamodule.valid_cuts()
valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
@@ -1242,7 +1242,7 @@ def scan_pessimistic_batches_for_oom(
def main():
raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()

View File

@@ -13,36 +13,36 @@ class MultiDataset:
Args:
manifest_dir:
It is expected to contain the following files:
- reazonspeech_cuts_train.jsonl.gz
- librispeech_cuts_train-clean-100.jsonl.gz
- librispeech_cuts_train-clean-360.jsonl.gz
- librispeech_cuts_train-other-500.jsonl.gz
- mls_english/
- mls_eng_cuts_train.jsonl.gz
- mls_eng_cuts_dev.jsonl.gz
- mls_eng_cuts_test.jsonl.gz
- reazonspeech/
- reazonspeech_cuts_train.jsonl.gz
- reazonspeech_cuts_dev.jsonl.gz
- reazonspeech_cuts_test.jsonl.gz
"""
self.fbank_dir = Path(args.manifest_dir)
self.manifest_dir = Path(args.manifest_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
logging.info("Loading Reazonspeech in lazy mode")
reazonspeech_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz"
logging.info("Loading Reazonspeech TRAIN set in lazy mode")
reazonspeech_train_cuts = load_manifest_lazy(
self.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
)
logging.info("Loading LibriSpeech in lazy mode")
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
logging.info("Loading MLS English TRAIN set in lazy mode")
mls_eng_train_cuts = load_manifest_lazy(
self.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
)
return CutSet.mux(
reazonspeech_cuts,
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
reazonspeech_train_cuts,
mls_eng_train_cuts,
weights=[
len(reazonspeech_cuts),
len(train_clean_100_cuts),
len(train_clean_360_cuts),
len(train_other_500_cuts),
len(reazonspeech_train_cuts),
len(mls_eng_train_cuts),
],
)
@@ -51,93 +51,90 @@ class MultiDataset:
logging.info("Loading Reazonspeech DEV set in lazy mode")
reazonspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
self.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
)
logging.info("Loading LibriSpeech DEV set in lazy mode")
dev_clean_cuts = self.dev_clean_cuts()
dev_other_cuts = self.dev_other_cuts()
logging.info("Loading MLS English DEV set in lazy mode")
mls_eng_dev_cuts = load_manifest_lazy(
self.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
)
return CutSet.mux(
reazonspeech_dev_cuts,
dev_clean_cuts,
dev_other_cuts,
mls_eng_dev_cuts,
weights=[
len(reazonspeech_dev_cuts),
len(dev_clean_cuts),
len(dev_other_cuts),
len(mls_eng_dev_cuts),
],
)
def test_cuts(self) -> Dict[str, CutSet]:
def test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
logging.info("Loading Reazonspeech set in lazy mode")
logging.info("Loading Reazonspeech TEST set in lazy mode")
reazonspeech_test_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz"
)
reazonspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
self.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
)
logging.info("Loading LibriSpeech set in lazy mode")
test_clean_cuts = self.test_clean_cuts()
test_other_cuts = self.test_other_cuts()
test_cuts = {
"reazonspeech_test": reazonspeech_test_cuts,
"reazonspeech_dev": reazonspeech_dev_cuts,
"librispeech_test_clean": test_clean_cuts,
"librispeech_test_other": test_other_cuts,
}
return test_cuts
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
logging.info("Loading MLS English TEST set in lazy mode")
mls_eng_test_cuts = load_manifest_lazy(
self.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
return CutSet.mux(
reazonspeech_test_cuts,
mls_eng_test_cuts,
weights=[
len(reazonspeech_test_cuts),
len(mls_eng_test_cuts),
],
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
# @lru_cache()
# def train_clean_100_cuts(self) -> CutSet:
# logging.info("About to get train-clean-100 cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
# )
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
# @lru_cache()
# def train_clean_360_cuts(self) -> CutSet:
# logging.info("About to get train-clean-360 cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
# )
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
# @lru_cache()
# def train_other_500_cuts(self) -> CutSet:
# logging.info("About to get train-other-500 cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
# )
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
# @lru_cache()
# def dev_clean_cuts(self) -> CutSet:
# logging.info("About to get dev-clean cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
# )
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
)
# @lru_cache()
# def dev_other_cuts(self) -> CutSet:
# logging.info("About to get dev-other cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
# )
# @lru_cache()
# def test_clean_cuts(self) -> CutSet:
# logging.info("About to get test-clean cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
# )
# @lru_cache()
# def test_other_cuts(self) -> CutSet:
# logging.info("About to get test-other cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
# )
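For reference, the mixing above relies on lhotse's CutSet.mux, which draws from several lazily opened manifests according to the given weights. A small illustrative sketch (placeholder paths standing in for args.manifest_dir; the len()-based weights mirror the code above):

from lhotse import CutSet, load_manifest_lazy

# Placeholder manifest paths; the recipe reads them from args.manifest_dir.
ja_cuts = load_manifest_lazy("data/manifests/reazonspeech_cuts_train.jsonl.gz")
en_cuts = load_manifest_lazy("data/manifests/mls_eng_cuts_train.jsonl.gz")

# Interleave both corpora into one stream, sampling each in proportion to its size,
# so the Japanese/English mixture follows the amount of data per language.
mixed = CutSet.mux(ja_cuts, en_cuts, weights=[len(ja_cuts), len(en_cuts)])

for cut in mixed:
    ...  # cuts from both corpora arrive as a single interleaved stream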

View File

@@ -63,7 +63,7 @@ import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@@ -740,7 +740,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -887,7 +887,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
multidataset_datamodule = MultiDatasetAsrDataModule(args)
if params.bilingual:
multi_dataset = MultiDataset(args)
@@ -904,8 +904,8 @@ def main():
test_sets = test_sets_cuts.keys()
test_cuts = [test_sets_cuts[k] for k in test_sets]
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
valid_cuts = multidataset_datamodule.valid_cuts()
test_cuts = multidataset_datamodule.test_cuts()
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]

View File

@@ -66,7 +66,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -272,7 +272,7 @@ def get_parser():
parser.add_argument(
"--bilingual",
type=str2bool,
default=False,
default=True,
help="Whether the model is bilingual or not. 1 = bilingual.",
)
@@ -804,7 +804,8 @@ def compute_loss(
texts = batch["supervisions"]["text"]
if not params.bilingual:
y = tokenizer.encode(texts, out_type=int)
raise NotImplementedError("only bilingual training has been implemented")
# y = tokenizer.encode(texts, out_type=int)
else:
y = sentencepiece_processor.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
@@ -1147,9 +1148,10 @@ def run(rank, world_size, args):
# <blk> is defined in local/prepare_lang_char.py
if not params.bilingual:
tokenizer = Tokenizer.load(args.lang, args.lang_type)
params.blank_id = tokenizer.piece_to_id("<blk>")
params.vocab_size = tokenizer.get_piece_size()
raise NotImplementedError("only bilingual training has been implemented")
# tokenizer = Tokenizer.load(args.lang, args.lang_type)
# params.blank_id = tokenizer.piece_to_id("<blk>")
# params.vocab_size = tokenizer.get_piece_size()
else:
sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load(params.bpe_model)
@@ -1212,12 +1214,13 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
multidataset_datamodule = MultiDatasetAsrDataModule(args)
if params.bilingual:
multi_dataset = MultiDataset(args)
train_cuts = multi_dataset.train_cuts()
else:
train_cuts = reazonspeech_corpus.train_cuts()
raise NotImplementedError("only bilingual training has been implemented")
# train_cuts = reazonspeech_corpus.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@@ -1242,6 +1245,7 @@ def run(rank, world_size, args):
# for subsampling
T = ((c.num_samples - 7) // 2 + 1) // 2
if not params.bilingual:
raise NotImplementedError("only bilingual training has been implemented")
# tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
else:
tokens = sentencepiece_processor.encode(
@@ -1272,6 +1276,8 @@ def run(rank, world_size, args):
if params.bilingual:
train_cuts = train_cuts.map(tokenize_and_encode_text)
else:
raise NotImplementedError("only bilingual training has been implemented")
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1280,15 +1286,20 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = reazonspeech_corpus.train_dataloaders(
# train_dl = reazonspeech_corpus.train_dataloaders(
# train_cuts, sampler_state_dict=sampler_state_dict
# )
train_dl = multidataset_datamodule.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.bilingual:
valid_cuts = reazonspeech_corpus.valid_cuts()
else:
valid_cuts = multi_dataset.dev_cuts()
valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
else:
raise NotImplementedError("only bilingual training has been implemented")
# valid_cuts = multi_dataset.dev_cuts()
valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
@@ -1386,7 +1397,8 @@ def display_and_save_batch(
if params.bilingual:
y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
else:
y = tokenizer.encode(supervisions["text"], out_type=int)
raise NotImplementedError("only bilingual training has been implemented")
# y = tokenizer.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
@@ -1442,7 +1454,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
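Because --bilingual is parsed with icefall's str2bool and now defaults to True, the deprecated monolingual mode has to be requested explicitly, and doing so only reaches the NotImplementedError guards above. A hedged sketch of the flag behaviour (flag definition as in the diff; the exact accepted spellings are assumed from icefall.utils.str2bool):

import argparse
from icefall.utils import str2bool

parser = argparse.ArgumentParser()
parser.add_argument(
    "--bilingual",
    type=str2bool,
    default=True,
    help="Whether the model is bilingual or not. 1 = bilingual.",
)

print(parser.parse_args([]).bilingual)                    # True  -> bilingual multi-dataset training
print(parser.parse_args(["--bilingual", "0"]).bilingual)  # False -> deprecated path, rejected at runtime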